diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 7ea5b753da6..950fbd4c516 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -224,7 +224,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
             lockfile_created_(false),
             iteration_completed_(false) {}
 
-    ~FileWriterIterator() {
+    ~FileWriterIterator() override {
       if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
         std::vector<string> cache_files;
         Status s = dataset()->env_->GetMatchingPaths(
@@ -630,6 +630,57 @@ class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
   const Tensor resource_handle_;
 };
 
+namespace {
+template <typename T, typename FullNameFn>
+Status SaveCache(IteratorStateWriter* writer, T* cache, FullNameFn full_name) {
+  size_t cache_size = cache->size();
+  TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheSize), cache_size));
+  for (size_t i = 0; i < cache_size; i++) {
+    auto& element = cache->at(i);
+    TF_RETURN_IF_ERROR(writer->WriteScalar(
+        full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
+        element.size()));
+    for (size_t j = 0; j < element.size(); ++j) {
+      TF_RETURN_IF_ERROR(writer->WriteTensor(
+          full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
+          element[j]));
+    }
+  }
+  return Status::OK();
+}
+
+template <typename T, typename FullNameFn>
+Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
+                    FullNameFn full_name) {
+  size_t cache_size;
+  {
+    int64 temp;
+    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
+    cache_size = static_cast<size_t>(temp);
+  }
+  for (size_t i = 0; i < cache_size; ++i) {
+    std::vector<Tensor> element;
+    size_t element_size;
+    {
+      int64 temp;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)), &temp));
+      element_size = static_cast<size_t>(temp);
+    }
+    element.reserve(element_size);
+    for (size_t j = 0; j < element_size; ++j) {
+      element.emplace_back();
+      TF_RETURN_IF_ERROR(reader->ReadTensor(
+          full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
+          &element.back()));
+    }
+    cache->emplace_back(std::move(element));
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 class CacheDatasetOp::MemoryDataset : public DatasetBase {
  public:
   explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
@@ -714,12 +765,7 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
               return Status::OK();
             }));
       }
-      mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
       InitializeIterator();
-      if (mode_ == Mode::read && !cache_->IsCompleted()) {
-        return errors::Internal(
-            "Cache should only be read after it has been completed.");
-      }
       return iterator_->Initialize(ctx);
     }
 
@@ -739,27 +785,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
 
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
-      if (cache_->IsClaimed()) {
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheClaimed), ""));
-        size_t cache_size = cache_->size();
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name(kCacheSize), cache_size));
-        for (size_t i = 0; i < cache_size; i++) {
-          auto& element = cache_->at(i);
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
-              element.size()));
-          for (size_t j = 0; j < element.size(); ++j) {
-            TF_RETURN_IF_ERROR(writer->WriteTensor(
-                full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
-                element[j]));
-          }
-        }
-        if (cache_->IsCompleted()) {
-          TF_RETURN_IF_ERROR(
-              writer->WriteScalar(full_name(kCacheCompleted), ""));
-        }
+      if (cache_->IsCompleted()) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
+        TF_RETURN_IF_ERROR(SaveCache(
+            writer, cache_, [this](const string& s) { return full_name(s); }));
       }
       return SaveInput(writer, iterator_);
     }
@@ -769,41 +798,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
       mutex_lock l(mu_);
       iterator_.reset();
       cache_->Reset();
-      {
-        int64 temp;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
-        mode_ = static_cast<Mode>(temp);
-      }
-      if (reader->Contains(full_name(kCacheClaimed))) {
-        CHECK(cache_->MaybeClaim());
-        size_t cache_size;
-        {
-          int64 temp;
-          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
-          cache_size = static_cast<size_t>(temp);
-        }
-        for (size_t i = 0; i < cache_size; ++i) {
-          std::vector<Tensor> element;
-          size_t element_size;
-          {
-            int64 temp;
-            TF_RETURN_IF_ERROR(reader->ReadScalar(
-                full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
-                &temp));
-            element_size = static_cast<size_t>(temp);
-          }
-          element.reserve(element_size);
-          for (size_t j = 0; j < element_size; ++j) {
-            element.emplace_back();
-            TF_RETURN_IF_ERROR(reader->ReadTensor(
-                full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
-                &element.back()));
-          }
-          cache_->emplace_back(std::move(element));
-        }
-        if (reader->Contains(full_name(kCacheCompleted))) {
-          cache_->Complete();
-        }
+      if (reader->Contains(full_name(kCacheCompleted))) {
+        std::vector<std::vector<Tensor>> temp_cache;
+        TF_RETURN_IF_ERROR(
+            RestoreCache(ctx, reader, &temp_cache,
+                         [this](const string& s) { return full_name(s); }));
+        cache_->Complete(std::move(temp_cache));
       }
       InitializeIterator();
       TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
@@ -814,13 +814,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
     class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
      public:
       explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
-          : DatasetIterator<MemoryDataset>(params), cache_(cache) {
-        CHECK(cache_);
-      }
+          : DatasetIterator<MemoryDataset>(params), cache_(cache) {}
 
       ~MemoryWriterIterator() override {
         mutex_lock l(mu_);
-        if (cache_->size() > 0 && !cache_->IsCompleted()) {
+        if (!temp_cache_.empty() && !cache_->IsCompleted()) {
           LOG(WARNING)
               << "The calling iterator did not fully read the dataset being "
                  "cached. In order to avoid unexpected truncation of the "
@@ -843,11 +841,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
         TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
         if (*end_of_sequence) {
-          cache_->Complete();
+          cache_->Complete(std::move(temp_cache_));
           return Status::OK();
         }
         RecordBufferEnqueue(ctx, *out_tensors);
-        cache_->emplace_back(*out_tensors);
+        temp_cache_.emplace_back(*out_tensors);
         return Status::OK();
       }
 
@@ -860,12 +858,22 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
 
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
+        if (!cache_->IsCompleted()) {
+          TF_RETURN_IF_ERROR(
+              SaveCache(writer, &temp_cache_,
+                        [this](const string& s) { return full_name(s); }));
+        }
         return SaveInput(writer, input_impl_);
       }
 
       Status RestoreInternal(IteratorContext* ctx,
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
+        if (!reader->Contains(full_name(kCacheCompleted))) {
+          TF_RETURN_IF_ERROR(
+              RestoreCache(ctx, reader, &temp_cache_,
+                           [this](const string& s) { return full_name(s); }));
+        }
         return RestoreInput(ctx, reader, input_impl_);
       }
 
@@ -873,7 +881,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
       mutex mu_;
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
       MemoryCache* const cache_ GUARDED_BY(mu_);  // not owned.
-    };  // MemoryWriterIterator
+      std::vector<std::vector<Tensor>> temp_cache_ GUARDED_BY(mu_);
+    };  // MemoryWriterIterator
 
     class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
      public:
@@ -943,25 +952,21 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
     };  // MemoryReaderIterator
 
     void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      switch (mode_) {
-        case Mode::read:
-          iterator_ = absl::make_unique<MemoryReaderIterator>(
-              MemoryReaderIterator::Params{dataset(),
-                                           strings::StrCat(prefix(), kImpl)},
-              cache_);
-          break;
-        case Mode::write:
-          iterator_ = absl::make_unique<MemoryWriterIterator>(
-              MemoryWriterIterator::Params{dataset(),
-                                           strings::StrCat(prefix(), kImpl)},
-              cache_);
+      if (cache_->IsCompleted()) {
+        iterator_ = absl::make_unique<MemoryReaderIterator>(
+            MemoryReaderIterator::Params{dataset(),
+                                         strings::StrCat(prefix(), kImpl)},
+            cache_);
+      } else {
+        iterator_ = absl::make_unique<MemoryWriterIterator>(
+            MemoryWriterIterator::Params{dataset(),
+                                         strings::StrCat(prefix(), kImpl)},
+            cache_);
       }
     }
 
    mutex mu_;
    MemoryCache* cache_ GUARDED_BY(mu_);  // not owned.
-    enum Mode { read, write };
-    Mode mode_ GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
  };  // MemoryIterator
 
diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc
index 2d77e0378f7..e4e1be28a97 100644
--- a/tensorflow/core/kernels/data/cache_ops.cc
+++ b/tensorflow/core/kernels/data/cache_ops.cc
@@ -32,14 +32,12 @@ const char kMemoryCache[] = "MemoryCache";
 
 string MemoryCache::DebugString() const { return kMemoryCache; }
 
-void MemoryCache::Complete() {
+void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
   mutex_lock l(mu_);
-  completed_ = true;
-}
-
-bool MemoryCache::IsClaimed() {
-  tf_shared_lock l(mu_);
-  return claimed_;
+  if (!completed_) {
+    cache_ = std::move(cache);
+    completed_ = true;
+  }
 }
 
 bool MemoryCache::IsCompleted() {
@@ -47,18 +45,8 @@
   return completed_;
 }
 
-bool MemoryCache::MaybeClaim() {
-  mutex_lock l(mu_);
-  if (!claimed_) {
-    claimed_ = true;
-    return true;
-  }
-  return false;
-}
-
 void MemoryCache::Reset() {
   mutex_lock l(mu_);
-  claimed_ = false;
   completed_ = false;
   cache_.clear();
 }
@@ -69,11 +57,6 @@
   return cache_[index];
 }
 
-void MemoryCache::emplace_back(std::vector<Tensor> element) {
-  mutex_lock l(mu_);
-  cache_.emplace_back(std::move(element));
-}
-
 size_t MemoryCache::size() {
   tf_shared_lock l(mu_);
   return cache_.size();
diff --git a/tensorflow/core/kernels/data/cache_ops.h b/tensorflow/core/kernels/data/cache_ops.h
index c022c06f291..2cb6eb6e8ed 100644
--- a/tensorflow/core/kernels/data/cache_ops.h
+++ b/tensorflow/core/kernels/data/cache_ops.h
@@ -34,33 +34,22 @@ class MemoryCache : public ResourceBase {
   string DebugString() const override;
 
   // Marks the cache as completed.
-  void Complete();
-
-  // Returns whether the cache is claimed.
-  bool IsClaimed();
+  void Complete(std::vector<std::vector<Tensor>>&& cache);
 
   // Returns whether the cache is completed.
   bool IsCompleted();
 
-  // Attempts to claim the cache, returning whether the cache was claimed.
-  bool MaybeClaim();
-
   // Resets the cache.
   void Reset();
 
   // Returns the element at the given index.
   const std::vector<Tensor>& at(int64 index);
 
-  // Adds the element to the cache.
-  void emplace_back(std::vector<Tensor> element);
-
   // Returns the size of the cache.
   size_t size();
 
  private:
   mutex mu_;
-  // Determines whether a writer has claimed the cache.
-  bool claimed_ GUARDED_BY(mu_) = false;
   // Determines whether all elements of the dataset have been cached.
   bool completed_ GUARDED_BY(mu_) = false;
   std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py
index a7df11464ff..6d3dc04a3e0 100644
--- a/tensorflow/python/data/kernel_tests/cache_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_test.py
@@ -352,6 +352,18 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertAllEqual(results, range(10))
 
+  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
+  def testCacheV2ConcurrentIterators(self):
+
+    dataset = dataset_ops.Dataset.range(10).cache()
+
+    it1 = iter(dataset)
+    it2 = iter(dataset)
+
+    for i in range(10):
+      self.assertEqual(next(it1), i)
+      self.assertEqual(next(it2), i)
+
 
 if __name__ == "__main__":
   test.main()