[tf.data] This CL changes how the in-memory cache resource is managed, making it possible for the cache dataset to support both a) sharing of the cache across iterators and b) serialization. As a consequence, this CL enables sharing of the cache across iterators for tf.distribute and tf.data service use cases (which require serialization support).
PiperOrigin-RevId: 307469217 Change-Id: Ia4f4384752609b83cb10078ebeb20dfc6c8a2d8f
parent b592f87bd8 · commit 7ebbab819e
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DummyMemoryCache"
+  visibility: HIDDEN
+}
@@ -27,8 +27,6 @@ namespace tensorflow {
 namespace grappler {
 namespace {

-constexpr char kCacheDataset[] = "CacheDataset";
-constexpr char kCacheDatasetV2[] = "CacheDatasetV2";
 constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
 constexpr char kShuffleDataset[] = "ShuffleDataset";
 constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
@@ -55,10 +53,6 @@ Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster,
       node.add_input(zero_node->name());
       // set `reshuffle_each_iteration` attr
       (*node.mutable_attr())[kReshuffleEachIteration].set_b(true);
-    } else if (node.op() == kCacheDatasetV2) {
-      *node.mutable_op() = kCacheDataset;
-      // remove `cache` input
-      node.mutable_input()->RemoveLast();
     }
   }
@@ -28,29 +28,6 @@ namespace tensorflow {
 namespace grappler {
 namespace {

-TEST(MakeStateless, Cache) {
-  using test::function::NDef;
-  GrapplerItem item;
-  item.graph = test::function::GDef(
-      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
-       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
-       NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_INT64}}),
-       NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}),
-       graph_tests_utils::MakeCacheV2Node("cache", "range", "filename",
-                                          "handle")},
-      {});
-
-  MakeStateless optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("cache", output));
-  int index = graph_utils::FindGraphNodeWithName("cache", output);
-  EXPECT_EQ(output.node(index).op(), "CacheDataset");
-  EXPECT_EQ(output.node(index).input_size(), 2);
-}
-
 TEST(MakeStateless, Shuffle) {
   using test::function::NDef;
   GrapplerItem item;
@@ -48,7 +48,6 @@ constexpr char kShardId[] = "shard_id";
 constexpr char kCreatedAt[] = "Created at";
 constexpr char kMemoryDatasetPrefix[] = "Memory";
 constexpr char kMemoryCache[] = "MemoryCache";
-constexpr char kTFData[] = "tf_data";
 constexpr char kCacheClaimed[] = "cache_claimed";
 constexpr char kCacheSize[] = "cache_size";
 constexpr char kCache[] = "cache";
@@ -58,10 +57,10 @@ constexpr char kIndex[] = "index";
 constexpr char kImpl[] = "Impl";
 constexpr char kCacheDataset[] = "CacheDataset";

-class CacheDatasetOp::FileDataset : public DatasetBase {
+class CacheDatasetOp::FileDatasetBase : public DatasetBase {
  public:
-  explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
+  FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                   string filename, Env* env)
       : DatasetBase(DatasetContext(ctx)),
         input_(input),
         filename_(std::move(filename)),
@@ -76,7 +75,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
     DCHECK_EQ(item_index_padding_size_, 7);
   }

-  ~FileDataset() override { input_->Unref(); }
+  ~FileDatasetBase() override { input_->Unref(); }

   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
@@ -107,17 +106,6 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
   }

  protected:
-  Status AsGraphDefInternal(SerializationContext* ctx,
-                            DatasetGraphDefBuilder* b,
-                            Node** output) const override {
-    Node* input_graph = nullptr;
-    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
-    Node* filename = nullptr;
-    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
-    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
-    return Status::OK();
-  }
-
   const DatasetBase* const input_;
   const tstring filename_;

@@ -131,10 +119,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
                                  tensor_index);
   }

-  class FileIterator : public DatasetIterator<FileDataset> {
+  class FileIterator : public DatasetIterator<FileDatasetBase> {
    public:
     explicit FileIterator(const Params& params)
-        : DatasetIterator<FileDataset>(params) {
+        : DatasetIterator<FileDatasetBase>(params) {
       if (params.dataset->env_
               ->FileExists(MetaFilename(params.dataset->filename_))
               .ok()) {
@@ -199,7 +187,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {

    private:
     // FileWriterIterator passes through and caches items from the input
-    // FileDataset.
+    // FileDatasetBase.
     //
     // This iterator is used when the cache directory is not found on disk. It
     // creates the cache directory, and passes on the underlying iterator's
@@ -214,10 +202,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
     // partial cache gets flushed to disk in files with prefix
     // <filename>_<shard_id> where shard_id is unique for each checkpoint.
     // When all elements have been produced, these shards get coalesced.
-    class FileWriterIterator : public DatasetIterator<FileDataset> {
+    class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
      public:
       explicit FileWriterIterator(const Params& params)
-          : DatasetIterator<FileDataset>(params),
+          : DatasetIterator<FileDatasetBase>(params),
             cur_index_(0),
             shard_id_(0),
             filename_(
@@ -483,10 +471,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
       bool iteration_completed_ TF_GUARDED_BY(mu_);
     };  // FileWriterIterator

-    class FileReaderIterator : public DatasetIterator<FileDataset> {
+    class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
      public:
       explicit FileReaderIterator(const Params& params)
-          : DatasetIterator<FileDataset>(params),
+          : DatasetIterator<FileDatasetBase>(params),
             cur_index_(0),
             reader_(dataset()->env_, dataset()->filename_),
             iterator_restored_(false) {}
@@ -606,14 +594,31 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
   static constexpr size_t kMaxItems = 10000000;  // 10 million
   const size_t item_index_padding_size_;
   const string tensor_format_string_;
-};  // FileDataset
+};  // FileDatasetBase

-class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
+class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
+ public:
+  using FileDatasetBase::FileDatasetBase;
+
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* input_graph = nullptr;
+    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
+    Node* filename = nullptr;
+    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
+    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
+    return Status::OK();
+  }
+};
+
+class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
  public:
   explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                          string filename, Env* env,
                          const Tensor& resource_handle)
-      : FileDataset(ctx, input, filename, env),
+      : FileDatasetBase(ctx, input, filename, env),
         resource_handle_(resource_handle) {}

  protected:
@@ -686,20 +691,15 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,

 }  // namespace

-class CacheDatasetOp::MemoryDataset : public DatasetBase {
+class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
  public:
-  explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
+  explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                          MemoryCache* cache)
       : DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
     input_->Ref();
   }

-  ~MemoryDataset() override {
-    input_->Unref();
-    if (cache_) {
-      cache_->Unref();
-    }
-  }
+  ~MemoryDatasetBase() override { input_->Unref(); }

   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
@@ -732,44 +732,13 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
   }

  protected:
-  Status AsGraphDefInternal(SerializationContext* ctx,
-                            DatasetGraphDefBuilder* b,
-                            Node** output) const override {
-    Node* input_node = nullptr;
-    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
-    Node* filename_node = nullptr;
-    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
-    TF_RETURN_IF_ERROR(
-        b->AddDataset(this, {input_node, filename_node}, output));
-    return Status::OK();
-  }
-
-  class MemoryIterator : public DatasetIterator<MemoryDataset> {
+  class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
    public:
     explicit MemoryIterator(const Params& params, MemoryCache* cache)
-        : DatasetIterator<MemoryDataset>(params), cache_(cache) {}
+        : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

-    ~MemoryIterator() override {
-      if (dataset()->cache_ == nullptr) {
-        cache_->Unref();
-      }
-    }
-
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
-      if (cache_ == nullptr) {
-        // Use the resource manager in the iterator context to get / create
-        // a cache.
-        ResourceMgr* mgr = ctx->resource_mgr();
-        const string name = strings::StrCat(
-            prefix(), name_utils::kDelimiter, dataset()->node_name(),
-            name_utils::kDelimiter, kMemoryCache);
-        TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
-            kTFData, name, &cache_, [](MemoryCache** cache) {
-              *cache = new MemoryCache();
-              return Status::OK();
-            }));
-      }
       InitializeIterator();
       return iterator_->Initialize(ctx);
     }
@@ -817,10 +786,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
     }

    private:
-    class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
+    class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
      public:
       explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
-          : DatasetIterator<MemoryDataset>(params), cache_(cache) {}
+          : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

       ~MemoryWriterIterator() override {
         mutex_lock l(mu_);
@@ -900,12 +869,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
       std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
     };  // MemoryWriterIterator

-    class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
+    class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
      public:
       explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
-          : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
-        CHECK(cache);
-      }
+          : DatasetIterator<MemoryDatasetBase>(params),
+            cache_(cache),
+            index_(0) {}

       Status Initialize(IteratorContext* ctx) override {
         // The memory allocated for the cache is owned by the parent
@@ -988,19 +957,80 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
   };  // MemoryIterator

   const DatasetBase* const input_;
-  MemoryCache* cache_ = nullptr;
-};  // MemoryDataset
+  MemoryCache* const cache_;
+};  // MemoryDatasetBase

-class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
+// This version of memory dataset has an exclusive ownership of the memory cache
+// resource. It supports sharing of the cache across different iterations of the
+// `repeat` transformation but not across different iterators.
+class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
  public:
-  explicit MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
-                           MemoryCache* cache,
-                           std::unique_ptr<OwnedResourceHandle> handle)
-      : MemoryDataset(ctx, input, cache), handle_(std::move(handle)) {}
+  MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
+                MemoryCache* cache, const ResourceHandle& resource_handle)
+      : MemoryDatasetBase(ctx, input, cache),
+        resource_handle_(resource_handle) {
+    cleanup_ = [this, mgr = ctx->resource_manager()]() {
+      DCHECK(cache_->RefCountIsOne());
+      Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
+                                          resource_handle_.name());
+      if (!s.ok()) {
+        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
+      }
+    };
+  }

-  Status CheckExternalState() const override {
-    return errors::FailedPrecondition(DebugString(),
-                                      " depends on memory cache resource.");
+  ~MemoryDataset() override {
+    cache_->Unref();
+    cleanup_();
+  }
+
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* input_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+    Node* filename_node = nullptr;
+    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
+    TF_RETURN_IF_ERROR(
+        b->AddDataset(this, {input_node, filename_node}, output));
+    return Status::OK();
+  }
+
+ private:
+  std::function<void()> cleanup_;
+  const ResourceHandle resource_handle_;
+};
+
+// This version of memory dataset has a shared ownership of the memory cache
+// resource. It supports sharing of the cache across different iterations of
+// the `repeat` transformation and also across different iterators.
+class CacheDatasetOp::MemoryDatasetV2
+    : public CacheDatasetOp::MemoryDatasetBase {
+ public:
+  MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
+                  MemoryCache* cache, const ResourceHandle& resource_handle)
+      : MemoryDatasetBase(ctx, input, cache),
+        resource_handle_(std::move(resource_handle)) {
+    cleanup_ = [this, mgr = ctx->resource_manager()]() {
+      if (cache_->RefCountIsOne()) {
+        Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
+                                            resource_handle_.name());
+        if (!s.ok()) {
+          if (errors::IsNotFound(s)) {
+            // This is a benign race resulting from concurrent deletion.
+            VLOG(1) << "Failed to delete cache resource: " << s.ToString();
+          } else {
+            LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
+          }
+        }
+      }
+    };
+  }
+
+  ~MemoryDatasetV2() override {
+    cache_->Unref();
+    cleanup_();
   }

  protected:
@@ -1013,7 +1043,7 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
     TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
     Node* resource_handle_node = nullptr;
     Tensor handle(DT_RESOURCE, TensorShape({}));
-    handle.scalar<ResourceHandle>()() = handle_->handle();
+    handle.scalar<ResourceHandle>()() = resource_handle_;
     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
     TF_RETURN_IF_ERROR(b->AddDataset(
         this, {input_node, filename_node, resource_handle_node}, output));
@@ -1021,7 +1051,8 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
   }

  private:
-  std::unique_ptr<OwnedResourceHandle> handle_;
+  std::function<void()> cleanup_;
+  const ResourceHandle resource_handle_;
 };

 CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
@@ -1033,22 +1064,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
   // Parse out the filenames tensor.
   tstring filename;
   OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));

   if (filename.empty()) {
+    static std::atomic<int64> resource_id_counter(0);
+    const string& container = ctx->resource_manager()->default_container();
+    auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
+                                resource_id_counter.fetch_add(1));
     if (op_version_ == 2) {
       MemoryCache* cache = nullptr;
-      OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &cache));
-
-      // Create a fresh handle for the resource because the input handle can
-      // become invalid after this op executes.
-      std::unique_ptr<OwnedResourceHandle> handle;
-      OP_REQUIRES_OK(
-          ctx, OwnedResourceHandle::Create(ctx, cache, kMemoryCache, &handle));
-
+      auto handle = HandleFromInput(ctx, 2);
+      Status s = ctx->resource_manager()->Lookup<MemoryCache>(
+          handle.container(), handle.name(), &cache);
+      if (errors::IsNotFound(s)) {
+        OP_REQUIRES_OK(ctx,
+                       ctx->resource_manager()->LookupOrCreate<MemoryCache>(
+                           container, name, &cache, [](MemoryCache** cache) {
+                             *cache = new MemoryCache();
+                             return Status::OK();
+                           }));
+        handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
+      } else {
+        OP_REQUIRES_OK(ctx, s);
+      }
       // Ownership of cache is transferred onto `MemoryDatasetV2`.
       *output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
     } else {
-      *output = new MemoryDataset(ctx, input, /*cache=*/nullptr);
+      MemoryCache* cache;
+      OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
+                              container, name, &cache, [](MemoryCache** cache) {
+                                *cache = new MemoryCache();
+                                return Status::OK();
+                              }));
+      auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
+      // Ownership of cache is transferred onto `MemoryDataset`.
+      *output = new MemoryDataset(ctx, input, cache, handle);
     }
   } else {
     if (op_version_ == 2) {
@@ -22,8 +22,8 @@ namespace data {

 class CacheDatasetOp : public UnaryDatasetOpKernel {
  public:
-  class FileDataset;
-  class MemoryDataset;
+  class FileDatasetBase;
+  class MemoryDatasetBase;

   static constexpr const char* const kDatasetType = "Cache";
   static constexpr const char* const kInputDataset = "input_dataset";
@@ -38,10 +38,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
                    DatasetBase** output) override;

  private:
+  class FileDataset;
   class FileDatasetV2;
+  class MemoryDataset;
   class MemoryDatasetV2;

-  int op_version_;
+  const int op_version_;
 };

 }  // namespace data
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
@@ -26,7 +27,7 @@ namespace tensorflow {
 namespace data {
 namespace {

-const char kMemoryCache[] = "MemoryCache";
+constexpr char kMemoryCache[] = "MemoryCache";

 }  // namespace

@@ -82,10 +83,11 @@ Status AnonymousMemoryCacheHandleOp::CreateResource(

 void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) {
   const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
-  // The resource is guaranteed to exist because the variant tensor wrapping the
-  // deleter is provided as an unused input to this op, which guarantees that it
-  // has not run yet.
-  OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete(handle));
+  // The resource might have been already deleted by the dataset.
+  Status s = ctx->resource_manager()->Delete(handle);
+  if (!errors::IsNotFound(s)) {
+    OP_REQUIRES_OK(ctx, s);
+  }
 }

 namespace {
@@ -96,6 +98,9 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousMemoryCache").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("DeleteMemoryCache").Device(DEVICE_CPU),
                         DeleteMemoryCacheOp);

+REGISTER_KERNEL_BUILDER(Name("DummyMemoryCache").Device(DEVICE_CPU),
+                        DummyResourceOp<MemoryCache>);
+
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
@@ -286,6 +286,28 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
 std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
     std::function<void(std::function<void()>)> runner, int max_parallelism);

+// Op for creating a typed dummy resource.
+//
+// This op is used to provide a resource "placeholder" for ops such as
+// `CacheDatasetV2` or `ShuffleDatasetV2` that expect a resource input.
+// Originally, the lifetime of the resources passed into these ops was managed
+// externally. After the implementation changed to manage the lifetime of the
+// resources (including creation) by the ops themselves, the resource input is
+// only needed to pass a resource handle through graph rewrites. When they are
+// invoked from user code, the implementation passes in a dummy resource.
+template <typename ResourceType>
+class DummyResourceOp : public OpKernel {
+ public:
+  explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
+    tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
+        ctx, /*container=*/"", /*name=*/"dummy_resource");
+  }
+};
+
 }  // namespace data
 }  // namespace tensorflow
@@ -504,6 +504,13 @@ REGISTER_OP("DeleteMemoryCache")
     .Input("deleter: variant")
     .SetShapeFn(shape_inference::NoOutputs);

+REGISTER_OP("DummyMemoryCache")
+    .Output("handle: resource")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
 REGISTER_OP("CacheDataset")
     .Input("input_dataset: variant")
     .Input("filename: string")
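For illustration only, a hedged sketch of calling the new op through the generated Python wrapper (assuming the standard gen_dataset_ops module path): it returns a scalar DT_RESOURCE tensor, and no resource is actually created behind the handle.

from tensorflow.python.ops import gen_dataset_ops

# Produces a placeholder resource handle; CacheDatasetV2 accepts it only
# so that a handle can flow through graph rewrites.
handle = gen_dataset_ops.dummy_memory_cache()
print(handle.dtype)  # <dtype: 'resource'>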
@@ -244,8 +244,6 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
         dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
     cache_dataset = repeat_dataset.cache()

-    # Create initialization ops for iterators without and with
-    # caching, respectively.
     self.assertDatasetProduces(cache_dataset, expected_output=[])

   @combinations.generate(test_base.default_test_combinations())
@@ -3526,6 +3526,8 @@ class RangeDataset(DatasetSource):
     return self._structure


+# This can be deleted after the forward compatibility window for switching
+# to using dummy resource expires on 5/20.
 class _MemoryCacheDeleter(object):
   """An object which cleans up an anonymous memory cache resource.
|
||||||
handle=self._handle, deleter=self._deleter)
|
handle=self._handle, deleter=self._deleter)
|
||||||
|
|
||||||
|
|
||||||
|
# This can be deleted after the forward compatibility window for switching
|
||||||
|
# to using dummy resource expires on 5/20.
|
||||||
class _MemoryCache(object):
|
class _MemoryCache(object):
|
||||||
"""Represents a memory cache resource."""
|
"""Represents a memory cache resource."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(_MemoryCache, self).__init__()
|
super(_MemoryCache, self).__init__()
|
||||||
self._device = context.context().device_name
|
if compat.forward_compatible(2020, 5, 20):
|
||||||
self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache())
|
self._handle = gen_dataset_ops.dummy_memory_cache()
|
||||||
self._resource_deleter = _MemoryCacheDeleter(
|
else:
|
||||||
handle=self._handle, device=self._device, deleter=self._deleter)
|
self._device = context.context().device_name
|
||||||
|
self._handle, self._deleter = gen_dataset_ops.anonymous_memory_cache()
|
||||||
|
self._resource_deleter = _MemoryCacheDeleter(
|
||||||
|
handle=self._handle, device=self._device, deleter=self._deleter)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def handle(self):
|
def handle(self):
|
||||||
|
|
|
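A sketch of how the forward-compatibility gate above can be exercised in tests; forward_compatibility_horizon temporarily advances the compatibility date so the new DummyMemoryCache branch is taken before the window actually expires (the dates below simply mirror the ones in the change):

from tensorflow.python.compat import compat

# Outside the window the gate is False and the old
# AnonymousMemoryCache path is used.
print(compat.forward_compatible(2020, 5, 20))

# Advancing the horizon forces the gate open, exercising the
# dummy_memory_cache branch.
with compat.forward_compatibility_horizon(2020, 5, 21):
  assert compat.forward_compatible(2020, 5, 20)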
@@ -1176,6 +1176,10 @@ tf_module {
     name: "DrawBoundingBoxesV2"
     argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DummyMemoryCache"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "DynamicPartition"
     argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1176,6 +1176,10 @@ tf_module {
     name: "DrawBoundingBoxesV2"
     argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DummyMemoryCache"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "DynamicPartition"
     argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "