From 7ebbab819e736319ec35b48e31f4d62fbad6626b Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Mon, 20 Apr 2020 13:51:58 -0700 Subject: [PATCH] [tf.data] This CL changes how the in-memory cache resource is managed, making it possible for the cache dataset to support both a) sharing of the cache across iterators and b) serialization. As a consequence, this CL enables sharing of the cache across iterators for tf.distribute and tf.data service use cases (which require serialization support). PiperOrigin-RevId: 307469217 Change-Id: Ia4f4384752609b83cb10078ebeb20dfc6c8a2d8f --- .../base_api/api_def_DummyMemoryCache.pbtxt | 4 + .../optimizers/data/make_stateless.cc | 6 - .../optimizers/data/make_stateless_test.cc | 23 -- .../core/kernels/data/cache_dataset_ops.cc | 240 +++++++++++------- .../core/kernels/data/cache_dataset_ops.h | 8 +- tensorflow/core/kernels/data/cache_ops.cc | 15 +- tensorflow/core/kernels/data/dataset_utils.h | 22 ++ tensorflow/core/ops/dataset_ops.cc | 7 + .../python/data/kernel_tests/cache_test.py | 2 - tensorflow/python/data/ops/dataset_ops.py | 15 +- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 12 files changed, 211 insertions(+), 139 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_DummyMemoryCache.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_DummyMemoryCache.pbtxt b/tensorflow/core/api_def/base_api/api_def_DummyMemoryCache.pbtxt new file mode 100644 index 00000000000..3b940d48bc7 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DummyMemoryCache.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "DummyMemoryCache" + visibility: HIDDEN +} diff --git a/tensorflow/core/grappler/optimizers/data/make_stateless.cc b/tensorflow/core/grappler/optimizers/data/make_stateless.cc index a18ca58f246..bce86078d3a 100644 --- a/tensorflow/core/grappler/optimizers/data/make_stateless.cc +++ b/tensorflow/core/grappler/optimizers/data/make_stateless.cc @@ -27,8 +27,6 @@ namespace tensorflow { namespace grappler { namespace { -constexpr char kCacheDataset[] = "CacheDataset"; -constexpr char kCacheDatasetV2[] = "CacheDatasetV2"; constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration"; constexpr char kShuffleDataset[] = "ShuffleDataset"; constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2"; @@ -55,10 +53,6 @@ Status MakeStateless::OptimizeAndCollectStats(Cluster* cluster, node.add_input(zero_node->name()); // set `reshuffle_each_iteration` attr (*node.mutable_attr())[kReshuffleEachIteration].set_b(true); - } else if (node.op() == kCacheDatasetV2) { - *node.mutable_op() = kCacheDataset; - // remove `cache` input - node.mutable_input()->RemoveLast(); } } diff --git a/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc b/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc index a30b7c63726..5cc6e6a88c6 100644 --- a/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc +++ b/tensorflow/core/grappler/optimizers/data/make_stateless_test.cc @@ -28,29 +28,6 @@ namespace tensorflow { namespace grappler { namespace { -TEST(MakeStateless, Cache) { - using test::function::NDef; - GrapplerItem item; - item.graph = test::function::GDef( - {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), - NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), - NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), - NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), - NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_INT64}}), - NDef("handle", "Const", {}, {{"value", 1}, {"dtype", DT_RESOURCE}}), - graph_tests_utils::MakeCacheV2Node("cache", "range", "filename", - "handle")}, - {}); - - MakeStateless optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("cache", output)); - int index = graph_utils::FindGraphNodeWithName("cache", output); - EXPECT_EQ(output.node(index).op(), "CacheDataset"); - EXPECT_EQ(output.node(index).input_size(), 2); -} - TEST(MakeStateless, Shuffle) { using test::function::NDef; GrapplerItem item; diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 55cef938ba0..556b859c781 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -48,7 +48,6 @@ constexpr char kShardId[] = "shard_id"; constexpr char kCreatedAt[] = "Created at"; constexpr char kMemoryDatasetPrefix[] = "Memory"; constexpr char kMemoryCache[] = "MemoryCache"; -constexpr char kTFData[] = "tf_data"; constexpr char kCacheClaimed[] = "cache_claimed"; constexpr char kCacheSize[] = "cache_size"; constexpr char kCache[] = "cache"; @@ -58,10 +57,10 @@ constexpr char kIndex[] = "index"; constexpr char kImpl[] = "Impl"; constexpr char kCacheDataset[] = "CacheDataset"; -class CacheDatasetOp::FileDataset : public DatasetBase { +class CacheDatasetOp::FileDatasetBase : public DatasetBase { public: - explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input, - string filename, Env* env) + FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input, + string filename, Env* env) : DatasetBase(DatasetContext(ctx)), input_(input), filename_(std::move(filename)), @@ -76,7 +75,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase { DCHECK_EQ(item_index_padding_size_, 7); } - ~FileDataset() override { input_->Unref(); } + ~FileDatasetBase() override { input_->Unref(); } std::unique_ptr MakeIteratorInternal( const string& prefix) const override { @@ -107,17 +106,6 @@ class CacheDatasetOp::FileDataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); - Node* filename = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output)); - return Status::OK(); - } - const DatasetBase* const input_; const tstring filename_; @@ -131,10 +119,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase { tensor_index); } - class FileIterator : public DatasetIterator { + class FileIterator : public DatasetIterator { public: explicit FileIterator(const Params& params) - : DatasetIterator(params) { + : DatasetIterator(params) { if (params.dataset->env_ ->FileExists(MetaFilename(params.dataset->filename_)) .ok()) { @@ -199,7 +187,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase { private: // FileWriterIterator passes through and caches items from the input - // FileDataset. + // FileDatasetBase. // // This iterator is used when the cache directory is not found on disk. It // creates the cache directory, and passes on the underlying iterator's @@ -214,10 +202,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase { // partial cache gets flushed to disk in files with prefix // _ where shard_id is unique for each checkpoint. // When all elements have been produced, these shards get coalesced. - class FileWriterIterator : public DatasetIterator { + class FileWriterIterator : public DatasetIterator { public: explicit FileWriterIterator(const Params& params) - : DatasetIterator(params), + : DatasetIterator(params), cur_index_(0), shard_id_(0), filename_( @@ -483,10 +471,10 @@ class CacheDatasetOp::FileDataset : public DatasetBase { bool iteration_completed_ TF_GUARDED_BY(mu_); }; // FileWriterIterator - class FileReaderIterator : public DatasetIterator { + class FileReaderIterator : public DatasetIterator { public: explicit FileReaderIterator(const Params& params) - : DatasetIterator(params), + : DatasetIterator(params), cur_index_(0), reader_(dataset()->env_, dataset()->filename_), iterator_restored_(false) {} @@ -606,14 +594,31 @@ class CacheDatasetOp::FileDataset : public DatasetBase { static constexpr size_t kMaxItems = 10000000; // 10 million const size_t item_index_padding_size_; const string tensor_format_string_; -}; // FileDataset +}; // FileDatasetBase -class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset { +class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase { + public: + using FileDatasetBase::FileDatasetBase; + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); + Node* filename = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output)); + return Status::OK(); + } +}; + +class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase { public: explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input, string filename, Env* env, const Tensor& resource_handle) - : FileDataset(ctx, input, filename, env), + : FileDatasetBase(ctx, input, filename, env), resource_handle_(resource_handle) {} protected: @@ -686,20 +691,15 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache, } // namespace -class CacheDatasetOp::MemoryDataset : public DatasetBase { +class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { public: - explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input, - MemoryCache* cache) + explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input, + MemoryCache* cache) : DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) { input_->Ref(); } - ~MemoryDataset() override { - input_->Unref(); - if (cache_) { - cache_->Unref(); - } - } + ~MemoryDatasetBase() override { input_->Unref(); } std::unique_ptr MakeIteratorInternal( const string& prefix) const override { @@ -732,44 +732,13 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); - Node* filename_node = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_node, filename_node}, output)); - return Status::OK(); - } - - class MemoryIterator : public DatasetIterator { + class MemoryIterator : public DatasetIterator { public: explicit MemoryIterator(const Params& params, MemoryCache* cache) - : DatasetIterator(params), cache_(cache) {} - - ~MemoryIterator() override { - if (dataset()->cache_ == nullptr) { - cache_->Unref(); - } - } + : DatasetIterator(params), cache_(cache) {} Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); - if (cache_ == nullptr) { - // Use the resource manager in the iterator context to get / create - // a cache. - ResourceMgr* mgr = ctx->resource_mgr(); - const string name = strings::StrCat( - prefix(), name_utils::kDelimiter, dataset()->node_name(), - name_utils::kDelimiter, kMemoryCache); - TF_RETURN_IF_ERROR(mgr->LookupOrCreate( - kTFData, name, &cache_, [](MemoryCache** cache) { - *cache = new MemoryCache(); - return Status::OK(); - })); - } InitializeIterator(); return iterator_->Initialize(ctx); } @@ -817,10 +786,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { } private: - class MemoryWriterIterator : public DatasetIterator { + class MemoryWriterIterator : public DatasetIterator { public: explicit MemoryWriterIterator(const Params& params, MemoryCache* cache) - : DatasetIterator(params), cache_(cache) {} + : DatasetIterator(params), cache_(cache) {} ~MemoryWriterIterator() override { mutex_lock l(mu_); @@ -900,12 +869,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { std::vector> temp_cache_ TF_GUARDED_BY(mu_); }; // MemoryWriterIterator - class MemoryReaderIterator : public DatasetIterator { + class MemoryReaderIterator : public DatasetIterator { public: explicit MemoryReaderIterator(const Params& params, MemoryCache* cache) - : DatasetIterator(params), cache_(cache), index_(0) { - CHECK(cache); - } + : DatasetIterator(params), + cache_(cache), + index_(0) {} Status Initialize(IteratorContext* ctx) override { // The memory allocated for the cache is owned by the parent @@ -988,19 +957,80 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase { }; // MemoryIterator const DatasetBase* const input_; - MemoryCache* cache_ = nullptr; -}; // MemoryDataset + MemoryCache* const cache_; +}; // MemoryDatasetBase -class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset { +// This version of memory dataset has an exclusive ownership of the memory cache +// resource. It supports sharing of the cache across different iterations of the +// `repeat` transformation but not across different iterators. +class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase { public: - explicit MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input, - MemoryCache* cache, - std::unique_ptr handle) - : MemoryDataset(ctx, input, cache), handle_(std::move(handle)) {} + MemoryDataset(OpKernelContext* ctx, const DatasetBase* input, + MemoryCache* cache, const ResourceHandle& resource_handle) + : MemoryDatasetBase(ctx, input, cache), + resource_handle_(resource_handle) { + cleanup_ = [this, mgr = ctx->resource_manager()]() { + DCHECK(cache_->RefCountIsOne()); + Status s = mgr->Delete(resource_handle_.container(), + resource_handle_.name()); + if (!s.ok()) { + LOG(WARNING) << "Failed to delete cache resource: " << s.ToString(); + } + }; + } - Status CheckExternalState() const override { - return errors::FailedPrecondition(DebugString(), - " depends on memory cache resource."); + ~MemoryDataset() override { + cache_->Unref(); + cleanup_(); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); + Node* filename_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_node, filename_node}, output)); + return Status::OK(); + } + + private: + std::function cleanup_; + const ResourceHandle resource_handle_; +}; + +// This version of memory dataset has a shared ownership of the memory cache +// resource. It supports sharing of the cache across different iterations of +// the `repeat` transformation and also across different iterators. +class CacheDatasetOp::MemoryDatasetV2 + : public CacheDatasetOp::MemoryDatasetBase { + public: + MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input, + MemoryCache* cache, const ResourceHandle& resource_handle) + : MemoryDatasetBase(ctx, input, cache), + resource_handle_(std::move(resource_handle)) { + cleanup_ = [this, mgr = ctx->resource_manager()]() { + if (cache_->RefCountIsOne()) { + Status s = mgr->Delete(resource_handle_.container(), + resource_handle_.name()); + if (!s.ok()) { + if (errors::IsNotFound(s)) { + // This is a bening race resulting from concurrent deletion. + VLOG(1) << "Failed to delete cache resource: " << s.ToString(); + } else { + LOG(WARNING) << "Failed to delete cache resource: " << s.ToString(); + } + } + } + }; + } + + ~MemoryDatasetV2() override { + cache_->Unref(); + cleanup_(); } protected: @@ -1013,7 +1043,7 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset { TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node)); Node* resource_handle_node = nullptr; Tensor handle(DT_RESOURCE, TensorShape({})); - handle.scalar()() = handle_->handle(); + handle.scalar()() = resource_handle_; TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node)); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_node, filename_node, resource_handle_node}, output)); @@ -1021,7 +1051,8 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset { } private: - std::unique_ptr handle_; + std::function cleanup_; + const ResourceHandle resource_handle_; }; CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx) @@ -1033,22 +1064,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, // Parse out the filenames tensor. tstring filename; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kFileName, &filename)); - if (filename.empty()) { + static std::atomic resource_id_counter(0); + const string& container = ctx->resource_manager()->default_container(); + auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_", + resource_id_counter.fetch_add(1)); if (op_version_ == 2) { MemoryCache* cache = nullptr; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 2), &cache)); - - // Create a fresh handle for the resource because the input handle can - // become invalid after this op executes. - std::unique_ptr handle; - OP_REQUIRES_OK( - ctx, OwnedResourceHandle::Create(ctx, cache, kMemoryCache, &handle)); - + auto handle = HandleFromInput(ctx, 2); + Status s = ctx->resource_manager()->Lookup( + handle.container(), handle.name(), &cache); + if (errors::IsNotFound(s)) { + OP_REQUIRES_OK(ctx, + ctx->resource_manager()->LookupOrCreate( + container, name, &cache, [](MemoryCache** cache) { + *cache = new MemoryCache(); + return Status::OK(); + })); + handle = MakeResourceHandle(ctx, container, name); + } else { + OP_REQUIRES_OK(ctx, s); + } // Ownership of cache is transferred onto `MemoryDatasetV2`. *output = new MemoryDatasetV2(ctx, input, cache, std::move(handle)); } else { - *output = new MemoryDataset(ctx, input, /*cache=*/nullptr); + MemoryCache* cache; + OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate( + container, name, &cache, [](MemoryCache** cache) { + *cache = new MemoryCache(); + return Status::OK(); + })); + auto handle = MakeResourceHandle(ctx, container, name); + // Ownership of cache is transferred onto `MemoryDataset`. + *output = new MemoryDataset(ctx, input, cache, handle); } } else { if (op_version_ == 2) { diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.h b/tensorflow/core/kernels/data/cache_dataset_ops.h index 484d0489336..e0ceee2a253 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.h +++ b/tensorflow/core/kernels/data/cache_dataset_ops.h @@ -22,8 +22,8 @@ namespace data { class CacheDatasetOp : public UnaryDatasetOpKernel { public: - class FileDataset; - class MemoryDataset; + class FileDatasetBase; + class MemoryDatasetBase; static constexpr const char* const kDatasetType = "Cache"; static constexpr const char* const kInputDataset = "input_dataset"; @@ -38,10 +38,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { DatasetBase** output) override; private: + class FileDataset; class FileDatasetV2; + class MemoryDataset; class MemoryDatasetV2; - int op_version_; + const int op_version_; }; } // namespace data diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc index 371a2ae5d25..8b58e7b9e45 100644 --- a/tensorflow/core/kernels/data/cache_ops.cc +++ b/tensorflow/core/kernels/data/cache_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random_distributions.h" @@ -26,7 +27,7 @@ namespace tensorflow { namespace data { namespace { -const char kMemoryCache[] = "MemoryCache"; +constexpr char kMemoryCache[] = "MemoryCache"; } // namespace @@ -82,10 +83,11 @@ Status AnonymousMemoryCacheHandleOp::CreateResource( void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) { const ResourceHandle& handle = ctx->input(0).flat()(0); - // The resource is guaranteed to exist because the variant tensor wrapping the - // deleter is provided as an unused input to this op, which guarantees that it - // has not run yet. - OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete(handle)); + // The resource might have been already deleted by the dataset. + Status s = ctx->resource_manager()->Delete(handle); + if (!errors::IsNotFound(s)) { + OP_REQUIRES_OK(ctx, s); + } } namespace { @@ -96,6 +98,9 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousMemoryCache").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DeleteMemoryCache").Device(DEVICE_CPU), DeleteMemoryCacheOp); +REGISTER_KERNEL_BUILDER(Name("DummyMemoryCache").Device(DEVICE_CPU), + DummyResourceOp); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index bedd5facda9..eb7316dd348 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -286,6 +286,28 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base, std::function)> RunnerWithMaxParallelism( std::function)> runner, int max_parallelism); +// Op for creating a typed dummy resource. +// +// This op is used to provide a resource "placeholder" for ops such as +// `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input. +// Originally, the lifetime of the resources passed into these ops was managed +// externally. After the implementation changed to manage the lifetime of the +// resources (including creation) by the ops themselves, the resource input is +// only needed to pass a resource handle through graph rewrites. When they are +// invoked from user code, the implementation passes in a dummy resource. +template +class DummyResourceOp : public OpKernel { + public: + explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Tensor* tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor)); + tensor->scalar()() = MakeResourceHandle( + ctx, /*container=*/"", /*name=*/"dummy_resource"); + } +}; + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 74e0d5bcf84..eed9b2a4d6a 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -504,6 +504,13 @@ REGISTER_OP("DeleteMemoryCache") .Input("deleter: variant") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("DummyMemoryCache") + .Output("handle: resource") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + REGISTER_OP("CacheDataset") .Input("input_dataset: variant") .Input("filename: string") diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py index 00068a7fd4c..a95424b6843 100644 --- a/tensorflow/python/data/kernel_tests/cache_test.py +++ b/tensorflow/python/data/kernel_tests/cache_test.py @@ -244,8 +244,6 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase): dataset_ops.Dataset.from_tensor_slices(components).repeat(0)) cache_dataset = repeat_dataset.cache() - # Create initialization ops for iterators without and with - # caching, respectively. self.assertDatasetProduces(cache_dataset, expected_output=[]) @combinations.generate(test_base.default_test_combinations()) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 2eddeb6eac6..7dcec3248ce 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -3526,6 +3526,8 @@ class RangeDataset(DatasetSource): return self._structure +# This can be deleted after the forward compatibility window for switching +# to using dummy resource expires on 5/20. class _MemoryCacheDeleter(object): """An object which cleans up an anonymous memory cache resource. @@ -3552,15 +3554,20 @@ class _MemoryCacheDeleter(object): handle=self._handle, deleter=self._deleter) +# This can be deleted after the forward compatibility window for switching +# to using dummy resource expires on 5/20. class _MemoryCache(object): """Represents a memory cache resource.""" def __init__(self): super(_MemoryCache, self).__init__() - self._device = context.context().device_name - self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache()) - self._resource_deleter = _MemoryCacheDeleter( - handle=self._handle, device=self._device, deleter=self._deleter) + if compat.forward_compatible(2020, 5, 20): + self._handle = gen_dataset_ops.dummy_memory_cache() + else: + self._device = context.context().device_name + self._handle, self._deleter = gen_dataset_ops.anonymous_memory_cache() + self._resource_deleter = _MemoryCacheDeleter( + handle=self._handle, device=self._device, deleter=self._deleter) @property def handle(self): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2fe82b3bc72..7c940c475c6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1176,6 +1176,10 @@ tf_module { name: "DrawBoundingBoxesV2" argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DummyMemoryCache" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DynamicPartition" argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2fe82b3bc72..7c940c475c6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1176,6 +1176,10 @@ tf_module { name: "DrawBoundingBoxesV2" argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DummyMemoryCache" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DynamicPartition" argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "