[tf.data] Memory-safe implementation of sharing access to the memory cache.

PiperOrigin-RevId: 307736215
Change-Id: If10ef65e6706a106e6bb4fc2d6fe4542bbe056cc
This commit is contained in:
Jiri Simsa 2020-04-21 20:43:40 -07:00 committed by TensorFlower Gardener
parent df7e7b1617
commit b546b463b5
3 changed files with 85 additions and 75 deletions

View File

@ -694,8 +694,10 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
public:
explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache)
: DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
std::shared_ptr<MemoryCache> cache)
: DatasetBase(DatasetContext(ctx)),
input_(input),
cache_(std::move(cache)) {
input_->Ref();
}
@ -708,7 +710,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
return absl::make_unique<MemoryIterator>(
MemoryIterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
cache_);
cache_.get());
}
const DataTypeVector& output_dtypes() const override {
@ -964,7 +966,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
}; // MemoryIterator
const DatasetBase* const input_;
MemoryCache* const cache_;
const std::shared_ptr<MemoryCache> cache_;
}; // MemoryDatasetBase
// This version of memory dataset has an exclusive ownership of the memory cache
@ -973,22 +975,19 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache, const ResourceHandle& resource_handle)
: MemoryDatasetBase(ctx, input, cache),
resource_handle_(resource_handle) {
cleanup_ = [this, mgr = ctx->resource_manager()]() {
DCHECK(cache_->RefCountIsOne());
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
};
}
MemoryCacheManager* manager, ResourceHandle&& resource_handle)
: MemoryDatasetBase(ctx, input, manager->get()),
manager_(manager),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()) {}
~MemoryDataset() override {
cache_->Unref();
cleanup_();
manager_->Unref();
Status s = resource_mgr_->Delete<MemoryCacheManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}
protected:
@ -1005,8 +1004,9 @@ class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
}
private:
std::function<void()> cleanup_;
MemoryCacheManager* const manager_; // Owned.
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
};
// This version of memory dataset has a shared ownership of the memory cache
@ -1016,28 +1016,23 @@ class CacheDatasetOp::MemoryDatasetV2
: public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache, const ResourceHandle& resource_handle)
: MemoryDatasetBase(ctx, input, cache),
resource_handle_(std::move(resource_handle)) {
cleanup_ = [this, mgr = ctx->resource_manager()]() {
if (cache_->RefCountIsOne()) {
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
resource_handle_.name());
if (!s.ok()) {
if (errors::IsNotFound(s)) {
// This is a bening race resulting from concurrent deletion.
VLOG(1) << "Failed to delete cache resource: " << s.ToString();
} else {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}
}
};
}
MemoryCacheManager* manager, ResourceHandle&& resource_handle,
bool owns_resource)
: MemoryDatasetBase(ctx, input, manager->get()),
manager_(manager),
owns_resource_(owns_resource),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()) {}
~MemoryDatasetV2() override {
cache_->Unref();
cleanup_();
manager_->Unref();
if (owns_resource_) {
Status s = resource_mgr_->Delete<MemoryCacheManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}
}
protected:
@ -1058,8 +1053,10 @@ class CacheDatasetOp::MemoryDatasetV2
}
private:
std::function<void()> cleanup_;
MemoryCacheManager* const manager_; // Owned.
const bool owns_resource_;
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
};
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
@ -1077,33 +1074,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
resource_id_counter.fetch_add(1));
if (op_version_ == 2) {
MemoryCache* cache = nullptr;
bool owns_resource = false;
MemoryCacheManager* manager = nullptr;
auto handle = HandleFromInput(ctx, 2);
Status s = ctx->resource_manager()->Lookup<MemoryCache>(
handle.container(), handle.name(), &cache);
Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
handle.container(), handle.name(), &manager);
if (errors::IsNotFound(s)) {
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->LookupOrCreate<MemoryCache>(
container, name, &cache, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
owns_resource = true;
OP_REQUIRES_OK(
ctx,
ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
container, name, &manager, [](MemoryCacheManager** manager) {
*manager = new MemoryCacheManager();
return Status::OK();
}));
handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
} else {
OP_REQUIRES_OK(ctx, s);
}
// Ownership of cache is transferred onto `MemoryDatasetV2`.
*output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
// Ownership of manager is transferred onto `MemoryDatasetV2`.
*output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
owns_resource);
} else {
MemoryCache* cache;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
container, name, &cache, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
// Ownership of cache is transferred onto `MemoryDataset`.
*output = new MemoryDataset(ctx, input, cache, handle);
MemoryCacheManager* manager;
OP_REQUIRES_OK(
ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
container, name, &manager, [](MemoryCacheManager** manager) {
*manager = new MemoryCacheManager();
return Status::OK();
}));
auto handle =
MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
// Ownership of manager is transferred onto `MemoryDataset`.
*output = new MemoryDataset(ctx, input, manager, std::move(handle));
}
} else {
if (op_version_ == 2) {

View File

@ -31,7 +31,7 @@ constexpr char kMemoryCache[] = "MemoryCache";
} // namespace
string MemoryCache::DebugString() const { return kMemoryCache; }
string MemoryCacheManager::DebugString() const { return kMemoryCache; }
void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
mutex_lock l(mu_);
@ -65,19 +65,15 @@ size_t MemoryCache::size() {
AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp(
OpKernelConstruction* ctx)
: AnonymousResourceOp<MemoryCache>(ctx) {}
void AnonymousMemoryCacheHandleOp::Compute(OpKernelContext* ctx) {
AnonymousResourceOp<MemoryCache>::Compute(ctx);
}
: AnonymousResourceOp<MemoryCacheManager>(ctx) {}
string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; }
Status AnonymousMemoryCacheHandleOp::CreateResource(
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib, MemoryCache** resource) {
*resource = new MemoryCache();
FunctionLibraryRuntime* lib, MemoryCacheManager** manager) {
*manager = new MemoryCacheManager();
return Status::OK();
}

View File

@ -27,12 +27,10 @@ namespace data {
// The expected use is that a single `MemoryWriterIterator` populates the
// cache with dataset elements. Once all elements are cached, the cache can
// be used by one or more `MemoryReaderIterator`s.
class MemoryCache : public ResourceBase {
class MemoryCache {
public:
MemoryCache() = default;
string DebugString() const override;
// Marks the cache as completed.
void Complete(std::vector<std::vector<Tensor>>&& cache);
@ -55,11 +53,24 @@ class MemoryCache : public ResourceBase {
std::vector<std::vector<Tensor>> cache_ TF_GUARDED_BY(mu_);
};
// A resource wrapping a shared instance of a memory cache.
class MemoryCacheManager : public ResourceBase {
public:
MemoryCacheManager() : cache_(std::make_shared<MemoryCache>()) {}
string DebugString() const override;
std::shared_ptr<MemoryCache> get() { return cache_; }
private:
std::shared_ptr<MemoryCache> cache_;
};
// Creates an instance of cache resource and transfers ownership to the caller.
class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
class AnonymousMemoryCacheHandleOp
: public AnonymousResourceOp<MemoryCacheManager> {
public:
explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
private:
string name() override;
@ -67,7 +78,7 @@ class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib,
MemoryCache** resource) override;
MemoryCacheManager** manager) override;
};
// Deletes an instance of cache resource.