diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 119f61d29ca..96bd7f12d48 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -1063,7 +1063,9 @@ tf_kernel_library(
 tf_kernel_library(
     name = "cache_dataset_ops",
     srcs = ["cache_dataset_ops.cc"],
+    hdrs = ["cache_dataset_ops.h"],
     deps = [
+        ":name_utils",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -1072,6 +1074,23 @@ tf_kernel_library(
     ],
 )
 
+tf_cc_test(
+    name = "cache_dataset_ops_test",
+    srcs = ["cache_dataset_ops_test.cc"],
+    deps = [
+        ":cache_dataset_ops",
+        ":dataset_test_base",
+        ":dataset_utils",
+        ":iterator_ops",
+        ":tensor_slice_dataset_op",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:ptr_util",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "optimize_dataset_op",
     srcs = ["optimize_dataset_op.cc"],
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 341a02cc259..9b1fed90463 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/cache_dataset_ops.h"
+
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/env.h"
@@ -23,117 +25,245 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
-namespace {
 
 // See documentation in ../../ops/dataset_ops.cc for a high-level description of
 // the following op.
-class CacheDatasetOp : public UnaryDatasetOpKernel {
+/* static */ constexpr const char* const CacheDatasetOp::kDatasetType;
+/* static */ constexpr const char* const CacheDatasetOp::kInputDataset;
+/* static */ constexpr const char* const CacheDatasetOp::kFileName;
+/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
+/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;
+
+constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
+constexpr char kPaddingSizeStrFormat[] = "%zu";
+constexpr char kFileDatasetPrefix[] = "File";
+constexpr char kMode[] = "Mode";
+constexpr char kLockFileSuffix[] = ".lockfile";
+constexpr char kIterationCompleted[] = "iteration_completed";
+constexpr char kCurIndex[] = "cur_index";
+constexpr char kShardId[] = "shard_id";
+constexpr char kCreatedAt[] = "Created at";
+constexpr char kMemoryDatasetPrefix[] = "Memory";
+constexpr char kMemoryCache[] = "MemoryCache";
+constexpr char kTFData[] = "tf_data";
+constexpr char kCacheClaimed[] = "cache_claimed";
+constexpr char kCacheSize[] = "cache_size";
+constexpr char kCache[] = "cache";
+constexpr char kSizeSuffix[] = ".size";
+constexpr char kCacheCompleted[] = "cache_completed";
+constexpr char kIndex[] = "index";
+constexpr char kImpl[] = "Impl";
+
+class CacheDatasetOp::FileDataset : public DatasetBase {
  public:
-  explicit CacheDatasetOp(OpKernelConstruction* ctx)
-      : UnaryDatasetOpKernel(ctx) {}
+  explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
+                       string filename, Env* env)
+      : DatasetBase(DatasetContext(ctx)),
+        input_(input),
+        filename_(std::move(filename)),
+        env_(env),
+        num_tensors_(input->output_dtypes().size()),
+        tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
+        item_index_padding_size_(StringPaddingSize(kMaxItems)),
+        tensor_format_string_(strings::Printf(kKeyStrFormat,
+                                              item_index_padding_size_,
+                                              tensor_index_padding_size_)) {
+    input_->Ref();
+    DCHECK_EQ(item_index_padding_size_, 7);
+  }
 
-  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
-                   DatasetBase** output) override {
-    // Parse out the filenames tensor.
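Worth spelling out for reviewers of this hunk: `kKeyStrFormat` is expanded in two stages. `strings::Printf` first bakes the computed padding widths into `tensor_format_string_`, and `FormatName` then renders fixed-width `<item>_<tensor>` keys whose lexicographic order matches numeric index order. A minimal standalone sketch of the same two-stage expansion, with plain `std::snprintf` standing in for `strings::Printf` and all variable names illustrative:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Stage 1: bake the computed widths into a concrete format string.
  // With an item padding of 7 and a tensor padding of 1 this yields
  // "%7zu_%1zu", mirroring tensor_format_string_.
  const size_t item_pad = 7, tensor_pad = 1;
  char fmt[32];
  std::snprintf(fmt, sizeof(fmt), "%%%zuzu_%%%zuzu", item_pad, tensor_pad);

  // Stage 2: render a fixed-width key, as FormatName does. Fixed width is
  // what keeps lexicographic key order consistent with index order.
  char key[32];
  std::snprintf(key, sizeof(key), fmt, size_t{42}, size_t{0});
  std::printf("format=\"%s\" key=\"%s\"\n", fmt, key);  // key = "     42_0"
}
```

With a padding of 7 this also reproduces the `DCHECK_EQ` above: `kMaxItems - 1` is 9999999, which is seven digits wide.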
- string filename; - OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "filename", &filename)); + ~FileDataset() override { input_->Unref(); } - if (filename.empty()) { - *output = new MemoryDataset(ctx, input); - } else { - *output = new FileDataset(ctx, input, filename, ctx->env()); - } + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + name_utils::IteratorPrefixParams params; + params.dataset_prefix = kFileDatasetPrefix; + return absl::make_unique(FileIterator::Params{ + this, name_utils::IteratorPrefix(kDatasetType, prefix, params)}); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + name_utils::DatasetDebugStringParams params; + params.dataset_prefix = kFileDatasetPrefix; + return name_utils::DatasetDebugString(kDatasetType, params); + } + + int64 Cardinality() const override { return input_->Cardinality(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); + Node* filename = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output)); + return Status::OK(); } private: - class FileDataset : public DatasetBase { + static size_t StringPaddingSize(size_t num_tensors) { + return strings::Printf(kPaddingSizeStrFormat, num_tensors - 1).size(); + } + + string FormatName(size_t item_index, size_t tensor_index) const { + return strings::Printf(tensor_format_string_.c_str(), item_index, + tensor_index); + } + + class FileIterator : public DatasetIterator { public: - explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input, - string filename, Env* env) - : DatasetBase(DatasetContext(ctx)), - input_(input), - filename_(std::move(filename)), - env_(env), - num_tensors_(input->output_dtypes().size()), - tensor_index_padding_size_(StringPaddingSize(num_tensors_)), - item_index_padding_size_(StringPaddingSize(kMaxItems)), - tensor_format_string_(strings::Printf("%%%zuzu_%%%zuzu", - item_index_padding_size_, - tensor_index_padding_size_)) { - input_->Ref(); - DCHECK_EQ(item_index_padding_size_, 7); + explicit FileIterator(const Params& params) + : DatasetIterator(params) { + if (params.dataset->env_ + ->FileExists(MetaFilename(params.dataset->filename_)) + .ok()) { + mode_ = Mode::read; + } else { + mode_ = Mode::write; + } + InitializeIterator(); } - ~FileDataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - FileIterator::Params{this, strings::StrCat(prefix, "::FileCache")}); + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + return iterator_->Initialize(ctx); } - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + return iterator_->GetNext(ctx, out_tensors, end_of_sequence); } - const std::vector& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { - return "CacheDatasetOp::FileDataset"; - } - - int64 Cardinality() const override { return input_->Cardinality(); } - 
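The read-versus-write decision in `FileIterator`'s constructor (and again in `RestoreInternal` below) reduces to probing for the bundle's metadata file. A sketch of that probe in isolation; `CacheIsComplete` is a hypothetical helper, while `Env::FileExists` and `MetaFilename` are the real APIs the constructor uses:

```cpp
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"

namespace tensorflow {
// Returns true if a finished cache bundle exists for `prefix`, i.e. a new
// iterator should start in read mode rather than write mode.
bool CacheIsComplete(Env* env, const string& prefix) {
  // The metadata file only appears once BundleWriter::Finish (or a later
  // MergeBundles) has succeeded, so its presence marks a complete cache.
  return env->FileExists(MetaFilename(prefix)).ok();
}
}  // namespace tensorflow
```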
   protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* input_graph = nullptr;
-      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
-      Node* filename = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
-      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
-      return Status::OK();
+    std::shared_ptr<model::Node> CreateNode(
+        IteratorContext* ctx, model::Node::Args args) const override {
+      return model::MakeKnownRatioNode(std::move(args),
+                                       /*ratio=*/1);
+    }
+
+    Status SaveInternal(IteratorStateWriter* writer) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
+      return SaveInput(writer, iterator_);
+    }
+    Status RestoreInternal(IteratorContext* ctx,
+                           IteratorStateReader* reader) override {
+      mutex_lock l(mu_);
+      {
+        int64 temp;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
+        mode_ = static_cast<Mode>(temp);
+      }
+      if (mode_ == Mode::write &&
+          dataset()
+              ->env_->FileExists(MetaFilename(dataset()->filename_))
+              .ok()) {
+        // This could happen if the cache was completely written after the
+        // checkpoint was saved.
+        LOG(WARNING)
+            << "It looks like the cache was already completely written ("
+            << MetaFilename(dataset()->filename_)
+            << ") after the last checkpoint was saved. Attempting to read "
+            << "the cache instead of continuing to write. If this is a "
+            << "mistake, please remove the above file and try running again.";
+        mode_ = Mode::read;
+      }
+      InitializeIterator();
+      TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+      return RestoreInput(ctx, reader, iterator_);
     }
 
    private:
-    static size_t StringPaddingSize(size_t num_tensors) {
-      return strings::Printf("%zu", num_tensors - 1).size();
-    }
-
-    string FormatName(size_t item_index, size_t tensor_index) const {
-      return strings::Printf(tensor_format_string_.c_str(), item_index,
-                             tensor_index);
-    }
-
-    class FileIterator : public DatasetIterator<FileDataset> {
+    // FileWriterIterator passes through and caches items from the input
+    // FileDataset.
+    //
+    // This iterator is used when the cache directory is not found on disk. It
+    // creates the cache directory, and passes on the underlying iterator's
+    // elements.
+    //
+    // Caching is performed by writing the input tensors to disk using the
+    // `BundleWriter`. Note that the cache gets fully flushed to disk only
+    // after the input iterator has been fully exhausted. If the program
+    // exits before completing an epoch, the cached state will be lost.
+    // To ensure that the partial cache persists across sessions, one should
+    // checkpoint the input pipeline. On each call to `SaveInternal` the
+    // partial cache gets flushed to disk in files with prefix
+    // <filename>_<shard_id> where shard_id is unique for each checkpoint.
+    // When all elements have been produced, these shards get coalesced.
+ class FileWriterIterator : public DatasetIterator { public: - explicit FileIterator(const Params& params) - : DatasetIterator(params) { - if (params.dataset->env_ - ->FileExists(MetaFilename(params.dataset->filename_)) - .ok()) { - mode_ = Mode::read; - } else { - mode_ = Mode::write; - } - InitializeIterator(); - } + explicit FileWriterIterator(const Params& params) + : DatasetIterator(params), + cur_index_(0), + shard_id_(0), + filename_( + strings::StrCat(params.dataset->filename_, "_", shard_id_)), + lockfile_(strings::StrCat(filename_, kLockFileSuffix)), + lockfile_created_(false), + iteration_completed_(false) {} Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - return iterator_->Initialize(ctx); + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - return iterator_->GetNext(ctx, out_tensors, end_of_sequence); + *end_of_sequence = false; + TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(writer_->status()); + if (cur_index_ >= kMaxItems) { + // As a courtesy, close the [truncated] cache file. + Status s = Finish(); + if (!s.ok()) { + LOG(ERROR) << s; + } + return errors::InvalidArgument( + "Upstream iterator is producing more than ", kMaxItems, + " items, which is more than the cache limit."); + } + + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence && out_tensors->empty()) { + TF_RETURN_IF_ERROR(Finish()); + cur_index_++; + return Status::OK(); + } + if (out_tensors->size() != dataset()->num_tensors_) { + return errors::Internal( + "Upstream iterator returned invalid number of tensors. " + "Expected ", + dataset()->num_tensors_, " got: ", out_tensors->size()); + } + size_t tensor_index = 0; + for (const Tensor& t : *out_tensors) { + DCHECK_LT(tensor_index, dataset()->num_tensors_); + string key = dataset()->FormatName(cur_index_, tensor_index++); + TF_RETURN_IF_ERROR(writer_->Add(key, t)); + } + if (*end_of_sequence) { + TF_RETURN_IF_ERROR(Finish()); + } + cur_index_++; + return Status::OK(); } protected: @@ -145,578 +275,219 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_)); - return SaveInput(writer, iterator_); - } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp)); - mode_ = static_cast(temp); - } - if (mode_ == Mode::write && - dataset() - ->env_->FileExists(MetaFilename(dataset()->filename_)) - .ok()) { - // This could happen if the cache was completely written after the - // checkpoint was saved. - LOG(WARNING) - << "It looks like the cache was already completely written(" - << MetaFilename(dataset()->filename_) - << ") after the last checkpoint was saved. Attempting to read " - << "the cache instead of continuing to write. If this is a " - << "mistake, please remove the above file and try running again."; - mode_ = Mode::read; - } - InitializeIterator(); - TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); - return RestoreInput(ctx, reader, iterator_); - } - - private: - // FileWriterIterator passes through and caches items from the input - // FileDataset. 
- // - // This iterator is used when the cache directory is not found on disk. It - // creates the cache directory, and passes on the underlying iterator's - // elements. - // - // Caching is performed by writing the input tensors to disk using the - // `BundleWriter`. Note that the cache gets fully flushed to disk only - // after the input iterator has been fully exhausted. If the program - // exits, before completion of an epoch, the cached state would be lost. - // To ensure that the partial cache persists across sessions, one should - // checkpoint the input pipeline. On each call to `SaveInternal` the - // partial cache gets flushed to disk in files with prefix - // _ where shard_id is unique for each checkpoint. - // When all elements have been produced, these shards get coalesced. - class FileWriterIterator : public DatasetIterator { - public: - explicit FileWriterIterator(const Params& params) - : DatasetIterator(params), - cur_index_(0), - shard_id_(0), - filename_( - strings::StrCat(params.dataset->filename_, "_", shard_id_)), - lockfile_(strings::StrCat(filename_, ".lockfile")), - lockfile_created_(false), - iteration_completed_(false) {} - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - *end_of_sequence = false; - TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence)); - if (*end_of_sequence) { - return Status::OK(); - } - TF_RETURN_IF_ERROR(writer_->status()); - if (cur_index_ >= kMaxItems) { - // As a courtesy, close the [truncated] cache file. - Status s = Finish(); - if (!s.ok()) { - LOG(ERROR) << s; - } - return errors::InvalidArgument( - "Upstream iterator is producing more than ", kMaxItems, - " items, which is more than the cache limit."); - } - + if (iteration_completed_) { TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (*end_of_sequence && out_tensors->empty()) { - TF_RETURN_IF_ERROR(Finish()); - cur_index_++; - return Status::OK(); - } - if (out_tensors->size() != dataset()->num_tensors_) { - return errors::Internal( - "Upstream iterator returned invalid number of tensors. " - "Expected ", - dataset()->num_tensors_, " got: ", out_tensors->size()); - } - size_t tensor_index = 0; - for (const Tensor& t : *out_tensors) { - DCHECK_LT(tensor_index, dataset()->num_tensors_); - string key = dataset()->FormatName(cur_index_, tensor_index++); - TF_RETURN_IF_ERROR(writer_->Add(key, t)); - } - if (*end_of_sequence) { - TF_RETURN_IF_ERROR(Finish()); - } - cur_index_++; + writer->WriteScalar(full_name(kIterationCompleted), "")); return Status::OK(); } - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (iteration_completed_) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("iteration_completed"), "")); - return Status::OK(); - } - - // lockfile is created on the first call to GetNextInternal. The - // absence of a lockfile means that GetNextInternal was not called - // and hence nothing was written to cache. So we don't need to worry - // about flushing the current shard. This ensures that we never write - // empty shards. - if (lockfile_created_) { - // Flush the current bundle. 
- TF_RETURN_IF_ERROR(writer_->Finish()); - - // Note: We do not delete the lockfile here. We keep lockfiles of - // all shards around until the entire cache has been written to - // prevent concurrent iterators from corrupting any of the shards. - - // Start caching to a new shard. - shard_id_++; - filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_); - lockfile_ = strings::StrCat(filename_, ".lockfile"); - lockfile_created_ = false; - } - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("cur_index"), cur_index_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("shard_id"), shard_id_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (reader->Contains(full_name("iteration_completed"))) { - iteration_completed_ = true; - return Status::OK(); - } - - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - int64 temp; - // TODO(b/78048575): Update this when saving size_t tensors directly - // is supported. - { - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("cur_index"), &temp)); - cur_index_ = static_cast(temp); - if (cur_index_ != temp) { - return errors::Internal("Invalid value for cur_index ", temp); - } - } - // TODO(b/78048575): Update this when saving size_t tensors directly - // is supported. - { - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("shard_id"), &temp)); - shard_id_ = static_cast(temp); - if (shard_id_ != temp) { - return errors::Internal("Invalid value for shard_id ", temp); - } - } - filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_); - lockfile_ = strings::StrCat(filename_, ".lockfile"); - writer_ = absl::make_unique(dataset()->env_, filename_); - return Status::OK(); - } - - private: - Status EnsureLockFileExists(bool* end_of_sequence) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (iteration_completed_) { - *end_of_sequence = true; - return Status::OK(); - } - if (lockfile_created_ && !iteration_completed_) return Status::OK(); - - // Perform rudimentary locking to help catch concurrent writes to the - // same cache files. - - // 1. Check that a checkpoint for the shard has not already been - // written. - if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) { - return errors::AlreadyExists("Existing cache files found: \n", - MetaFilename(filename_), "\n", - DataFilename(filename_, 0, 1), "\n", - "To continue delete the above files."); - } - - // 2. Check that there isn't a concurrent iterator that is writing - // to cache. - if (dataset()->env_->FileExists(lockfile_).ok()) { - // Attempt to read the contents of the lockfile. - char contents_scratch[151] = {0}; // Initialize all to 0. - StringPiece contents; - std::unique_ptr file; - if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) { - file->Read(0, 150, &contents, contents_scratch).IgnoreError(); - } - return errors::AlreadyExists( - "There appears to be a concurrent caching iterator running - " - "cache lockfile already exists ('", - lockfile_, - "'). If you are sure no other running TF computations are " - "using this cache prefix, delete the lockfile and " - "re-initialize the iterator. Lockfile contents: ", - contents); - } - // Create the file, and write some basic contents. 
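The locking scheme being moved in this hunk follows a fixed order: fail if a finished bundle already exists, fail if another writer's lockfile is present, and only then create the lockfile and the `BundleWriter`. A condensed sketch of that ordering under a hypothetical `AcquireCacheLock` helper (error messages abbreviated; like the original, the check-then-create sequence is rudimentary and narrows rather than eliminates the race):

```cpp
#include <memory>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"

namespace tensorflow {
Status AcquireCacheLock(Env* env, const string& prefix) {
  const string lockfile = prefix + ".lockfile";
  // 1. A finished bundle with this prefix must not already exist.
  if (env->FileExists(MetaFilename(prefix)).ok()) {
    return errors::AlreadyExists("Existing cache files found for ", prefix);
  }
  // 2. No concurrent writer may hold the lock.
  if (env->FileExists(lockfile).ok()) {
    return errors::AlreadyExists("Concurrent caching iterator: ", lockfile);
  }
  // 3. Claim the prefix. Only after this is it safe to construct a
  // BundleWriter, whose constructor creates temp files under the prefix.
  std::unique_ptr<WritableFile> file;
  TF_RETURN_IF_ERROR(env->NewWritableFile(lockfile, &file));
  TF_RETURN_IF_ERROR(file->Append("Created at: <timestamp>"));
  return file->Close();
}
}  // namespace tensorflow
```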
- std::unique_ptr lockfile; - TF_RETURN_IF_ERROR( - dataset()->env_->NewWritableFile(lockfile_, &lockfile)); - TF_RETURN_IF_ERROR(lockfile->Append( - strings::StrCat("Created at: ", dataset()->env_->NowSeconds()))); - - // At this point we know that - // 1. There is no conflicting checkpoint with prefix `filename_`. - // 2. There is no concurrent session that is trying to write a ckpt - // to filename. - // So it is safe to create a BundleWriter here. Note that it is - // unsafe to initialize the BundleWriter anywhere the above - // conditions are not met since BundleWriter's constructor creates - // new temp files which can delete the temp files created by a - // BundleWriter in another Session. - writer_ = absl::make_unique(dataset()->env_, filename_); - lockfile_created_ = true; - return Status::OK(); - } - - Status Finish() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - iteration_completed_ = true; + // lockfile is created on the first call to GetNextInternal. The + // absence of a lockfile means that GetNextInternal was not called + // and hence nothing was written to cache. So we don't need to worry + // about flushing the current shard. This ensures that we never write + // empty shards. + if (lockfile_created_) { // Flush the current bundle. TF_RETURN_IF_ERROR(writer_->Finish()); - // Merge all the bundles. - // Currently there are `shard_id_ + 1` bundles, one for each - // checkpoint. Each bundle has prefix _ where `id` is an - // integer starting at 0 an incremented by 1 for each new checkpoint. - // We merge all these bundles into a bundle with prefix so - // that the next call to `MakeIterator` can build a - // `FileReaderIterator`. - { - std::vector prefixes; - prefixes.reserve(shard_id_ + 1); - for (size_t i = 0; i <= shard_id_; ++i) { - prefixes.emplace_back( - strings::StrCat(dataset()->filename_, "_", i)); - } - TF_RETURN_IF_ERROR( - MergeBundles(dataset()->env_, prefixes, dataset()->filename_)); - } - // Delete all lockfiles. - for (size_t i = 0; i <= shard_id_; ++i) { - TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile( - strings::StrCat(dataset()->filename_, "_", i, ".lockfile"))); - } + + // Note: We do not delete the lockfile here. We keep lockfiles of + // all shards around until the entire cache has been written to + // prevent concurrent iterators from corrupting any of the shards. + + // Start caching to a new shard. + shard_id_++; + filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_); + lockfile_ = strings::StrCat(filename_, kLockFileSuffix); + lockfile_created_ = false; + } + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kCurIndex), cur_index_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (reader->Contains(full_name(kIterationCompleted))) { + iteration_completed_ = true; return Status::OK(); } - mutex mu_; - size_t cur_index_ GUARDED_BY(mu_); - // Index of the current shard. This gets incremented whenever a new - // cache shard is saved. - size_t shard_id_ GUARDED_BY(mu_); - std::unique_ptr input_impl_ GUARDED_BY(mu_); - // The current prefix for the cache file. This is equal to - // `StrCat(dataset()->filename_, "_", shard_id_)`. 
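`Finish()` (removed here and re-added unchanged on the new `FileWriterIterator` below) boils down to two tensor-bundle calls: flush the last shard, then merge `<filename>_0 .. <filename>_N` into one bundle under the bare `<filename>` prefix. A sketch of just the merge step, assuming the era's `MergeBundles` overload taking string prefixes and a hypothetical `CoalesceShards` wrapper:

```cpp
#include <vector>

#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
// Coalesces the per-checkpoint shards <filename>_0 .. <filename>_N into a
// single bundle under `filename`, the prefix FileReaderIterator later opens.
Status CoalesceShards(Env* env, const string& filename, size_t last_shard) {
  std::vector<string> prefixes;
  prefixes.reserve(last_shard + 1);
  for (size_t i = 0; i <= last_shard; ++i) {
    prefixes.push_back(strings::StrCat(filename, "_", i));
  }
  return MergeBundles(env, prefixes, filename);
}
}  // namespace tensorflow
```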
- string filename_; - std::unique_ptr writer_ GUARDED_BY(mu_); - string lockfile_ GUARDED_BY(mu_); - bool lockfile_created_ GUARDED_BY(mu_); - bool iteration_completed_ GUARDED_BY(mu_); - }; // FileWriterIterator - - class FileReaderIterator : public DatasetIterator { - public: - explicit FileReaderIterator(const Params& params) - : DatasetIterator(params), - cur_index_(0), - reader_(dataset()->env_, dataset()->filename_), - iterator_restored_(false) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - *end_of_sequence = false; - TF_RETURN_IF_ERROR(reader_.status()); - if (!reader_.Valid()) { - *end_of_sequence = true; - return Status::OK(); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 temp; + // TODO(b/78048575): Update this when saving size_t tensors directly + // is supported. + { + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &temp)); + cur_index_ = static_cast(temp); + if (cur_index_ != temp) { + return errors::Internal("Invalid value for cur_index ", temp); } - out_tensors->clear(); - out_tensors->resize(dataset()->num_tensors_); - - for (size_t i = 0; i < dataset()->num_tensors_; ++i) { - // When the iterator is restored from the checkpoint, `reader_` is - // already pointing at `key` so we do not need to skip the header - // entry. - if (!iterator_restored_) { - reader_.Next(); // The first entry in the table is a header. - } else { - iterator_restored_ = false; - } - if (!reader_.Valid()) { - out_tensors->clear(); - *end_of_sequence = true; - return Status::OK(); - } - StringPiece key = reader_.key(); - DCHECK_EQ(key, dataset()->FormatName(cur_index_, i)); - TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i])); - TF_RETURN_IF_ERROR(reader_.status()); + } + // TODO(b/78048575): Update this when saving size_t tensors directly + // is supported. + { + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kShardId), &temp)); + shard_id_ = static_cast(temp); + if (shard_id_ != temp) { + return errors::Internal("Invalid value for shard_id ", temp); } - cur_index_++; - return Status::OK(); } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("cur_index"), cur_index_)); - return Status::OK(); - } - - Status RestoreInternal( - IteratorContext* ctx, - IteratorStateReader* iterator_state_reader) override { - mutex_lock l(mu_); - { - // TODO(b/78048575): Update this when saving size_t tensors directly - // is supported. - int64 temp; - TF_RETURN_IF_ERROR(iterator_state_reader->ReadScalar( - full_name("cur_index"), &temp)); - cur_index_ = static_cast(temp); - if (cur_index_ != temp) { - return errors::Internal("Invalid value for cur_index ", temp); - } - } - if (!reader_.Valid()) { - return errors::Internal("Error initializing BundleReader."); - } - reader_.Seek(dataset()->FormatName(cur_index_, 0)); - iterator_restored_ = true; - return Status::OK(); - } - - private: - mutex mu_; - size_t cur_index_ GUARDED_BY(mu_); - BundleReader reader_ GUARDED_BY(mu_); - bool iterator_restored_ GUARDED_BY(mu_); - }; // FileReaderIterator - - void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // We intentionally use the same prefix for both `FileReaderIterator` - // and `FileWriterIterator`. 
Since at any time there will be at most - // one of them alive, there should be no conflicts. This allows both - // iterators to use a common key for `cur_index`. We leverage this - // in the corner case when this iterator is restored from an old - // checkpoint in `write` mode and the cache has been completely - // flushed to disk since then. In that case we simply build a - // `FileReaderIterator` and seek to the `cur_index`. - switch (mode_) { - case Mode::read: - iterator_ = absl::make_unique( - FileReaderIterator::Params{dataset(), - strings::StrCat(prefix(), "Impl")}); - break; - case Mode::write: - iterator_ = absl::make_unique( - FileWriterIterator::Params{dataset(), - strings::StrCat(prefix(), "Impl")}); - } - } - - mutex mu_; - enum Mode { read, write }; - Mode mode_ GUARDED_BY(mu_); - std::unique_ptr iterator_ GUARDED_BY(mu_); - }; // FileIterator - - const DatasetBase* const input_; - const string filename_; - Env* const env_; - const size_t num_tensors_; - const size_t tensor_index_padding_size_; - static const size_t kMaxItems = 10000000; // 10 million - const size_t item_index_padding_size_; - const string tensor_format_string_; - }; // FileDataset - - class MemoryDataset : public DatasetBase { - public: - explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input) - : DatasetBase(DatasetContext(ctx)), input_(input) { - input->Ref(); - } - - ~MemoryDataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique(MemoryIterator::Params{ - this, strings::StrCat(prefix, "::MemoryCache")}); - } - - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { - return "CacheDatasetOp::MemoryDataset"; - } - - int64 Cardinality() const override { return input_->Cardinality(); } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); - Node* filename_node = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_node, filename_node}, output)); - return Status::OK(); - } - - private: - // A thread-safe data structure for caching dataset elements. - // - // The expected use is that a single `MemoryWriterIterator` populates the - // cache with dataset elements. Once all elements are cached, the cache can - // be used by one or more `MemoryReaderIterator`s. - class MemoryCache : public ResourceBase { - public: - MemoryCache() = default; - - string DebugString() const override { - return "CacheDataset::MemoryCache"; - } - - // Marks the cache as completed. - void Complete() { - mutex_lock l(mu_); - completed_ = true; - } - - // Returns whether the cache is claimed. - bool IsClaimed() { - tf_shared_lock l(mu_); - return claimed_; - } - - // Returns whether the cache is completed. - bool IsCompleted() { - tf_shared_lock l(mu_); - return completed_; - } - - // Attempts to claim the cache, returning whether the cache was claimed. - bool MaybeClaim() { - mutex_lock l(mu_); - if (!claimed_) { - claimed_ = true; - return true; - } - return false; - } - - // Resets the cache. 
- void Reset() { - mutex_lock l(mu_); - claimed_ = false; - completed_ = false; - cache_.clear(); - } - - // Returns the element at the given index. - const std::vector& at(int64 index) { - tf_shared_lock l(mu_); - DCHECK(index < cache_.size()); - return cache_[index]; - } - - // Adds the element to the cache. - void emplace_back(std::vector element) { - mutex_lock l(mu_); - cache_.emplace_back(std::move(element)); - } - - // Returns the size of the cache. - size_t size() { - tf_shared_lock l(mu_); - return cache_.size(); + filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_); + lockfile_ = strings::StrCat(filename_, kLockFileSuffix); + writer_ = absl::make_unique(dataset()->env_, filename_); + return Status::OK(); } private: - mutex mu_; - // Determines whether a writer has claimed the cache. - bool claimed_ GUARDED_BY(mu_) = false; - // Determines whether all elements of the dataset have been cached. - bool completed_ GUARDED_BY(mu_) = false; - std::vector> cache_ GUARDED_BY(mu_); - }; - - class MemoryIterator : public DatasetIterator { - public: - explicit MemoryIterator(const Params& params) - : DatasetIterator(params) {} - - ~MemoryIterator() override { cache_->Unref(); } - - Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - // Use the resource manager in the iterator context to get / create - // a cache. - ResourceMgr* mgr = ctx->resource_mgr(); - const string name = strings::StrCat( - prefix(), "::", dataset()->node_name(), "::MemoryCache"); - TF_RETURN_IF_ERROR(mgr->LookupOrCreate( - "tf_data", name, &cache_, [](MemoryCache** cache) { - *cache = new MemoryCache(); - return Status::OK(); - })); - mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read; - InitializeIterator(); - if (mode_ == Mode::read && !cache_->IsCompleted()) { - return errors::Internal( - "Cache should only be read after it has been completed."); + Status EnsureLockFileExists(bool* end_of_sequence) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (iteration_completed_) { + *end_of_sequence = true; + return Status::OK(); } - return iterator_->Initialize(ctx); + if (lockfile_created_ && !iteration_completed_) return Status::OK(); + + // Perform rudimentary locking to help catch concurrent writes to the + // same cache files. + + // 1. Check that a checkpoint for the shard has not already been + // written. + if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) { + return errors::AlreadyExists("Existing cache files found: \n", + MetaFilename(filename_), "\n", + DataFilename(filename_, 0, 1), "\n", + "To continue delete the above files."); + } + + // 2. Check that there isn't a concurrent iterator that is writing + // to cache. + if (dataset()->env_->FileExists(lockfile_).ok()) { + // Attempt to read the contents of the lockfile. + char contents_scratch[151] = {0}; // Initialize all to 0. + StringPiece contents; + std::unique_ptr file; + if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) { + file->Read(0, 150, &contents, contents_scratch).IgnoreError(); + } + return errors::AlreadyExists( + "There appears to be a concurrent caching iterator running - " + "cache lockfile already exists ('", + lockfile_, + "'). If you are sure no other running TF computations are " + "using this cache prefix, delete the lockfile and " + "re-initialize the iterator. Lockfile contents: ", + contents); + } + // Create the file, and write some basic contents. 
+      std::unique_ptr<WritableFile> lockfile;
+      TF_RETURN_IF_ERROR(
+          dataset()->env_->NewWritableFile(lockfile_, &lockfile));
+      TF_RETURN_IF_ERROR(lockfile->Append(
+          strings::StrCat(kCreatedAt, ": ", dataset()->env_->NowSeconds())));
+
+      // At this point we know that
+      // 1. There is no conflicting checkpoint with prefix `filename_`.
+      // 2. There is no concurrent session that is trying to write a ckpt
+      //    to filename.
+      // So it is safe to create a BundleWriter here. Note that it is
+      // unsafe to initialize the BundleWriter anywhere the above
+      // conditions are not met since BundleWriter's constructor creates
+      // new temp files which can delete the temp files created by a
+      // BundleWriter in another Session.
+      writer_ = absl::make_unique<BundleWriter>(dataset()->env_, filename_);
+      lockfile_created_ = true;
+      return Status::OK();
     }
 
+    Status Finish() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      iteration_completed_ = true;
+      // Flush the current bundle.
+      TF_RETURN_IF_ERROR(writer_->Finish());
+      // Merge all the bundles.
+      // Currently there are `shard_id_ + 1` bundles, one for each
+      // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
+      // integer starting at 0 and incremented by 1 for each new checkpoint.
+      // We merge all these bundles into a bundle with prefix <filename> so
+      // that the next call to `MakeIterator` can build a
+      // `FileReaderIterator`.
+      {
+        std::vector<string> prefixes;
+        prefixes.reserve(shard_id_ + 1);
+        for (size_t i = 0; i <= shard_id_; ++i) {
+          prefixes.emplace_back(
+              strings::StrCat(dataset()->filename_, "_", i));
+        }
+        TF_RETURN_IF_ERROR(
+            MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
+      }
+      // Delete all lockfiles.
+      for (size_t i = 0; i <= shard_id_; ++i) {
+        TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
+            strings::StrCat(dataset()->filename_, "_", i, kLockFileSuffix)));
+      }
+      return Status::OK();
+    }
+
+    mutex mu_;
+    size_t cur_index_ GUARDED_BY(mu_);
+    // Index of the current shard. This gets incremented whenever a new
+    // cache shard is saved.
+    size_t shard_id_ GUARDED_BY(mu_);
+    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    // The current prefix for the cache file. This is equal to
+    // `StrCat(dataset()->filename_, "_", shard_id_)`.
+    string filename_;
+    std::unique_ptr<BundleWriter> writer_ GUARDED_BY(mu_);
+    string lockfile_ GUARDED_BY(mu_);
+    bool lockfile_created_ GUARDED_BY(mu_);
+    bool iteration_completed_ GUARDED_BY(mu_);
+  };  // FileWriterIterator
+
+  class FileReaderIterator : public DatasetIterator<FileDataset> {
+   public:
+    explicit FileReaderIterator(const Params& params)
+        : DatasetIterator<FileDataset>(params),
+          cur_index_(0),
+          reader_(dataset()->env_, dataset()->filename_),
+          iterator_restored_(false) {}
+
+    Status GetNextInternal(IteratorContext* ctx,
+                           std::vector<Tensor>* out_tensors,
+                           bool* end_of_sequence) override {
+      mutex_lock l(mu_);
+      *end_of_sequence = false;
+      TF_RETURN_IF_ERROR(reader_.status());
+      if (!reader_.Valid()) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
+      out_tensors->clear();
+      out_tensors->resize(dataset()->num_tensors_);
+
+      for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
+        // When the iterator is restored from the checkpoint, `reader_` is
+        // already pointing at `key` so we do not need to skip the header
+        // entry.
+        if (!iterator_restored_) {
+          reader_.Next();  // The first entry in the table is a header.
+ } else { + iterator_restored_ = false; + } + if (!reader_.Valid()) { + out_tensors->clear(); + *end_of_sequence = true; + return Status::OK(); + } + StringPiece key = reader_.key(); + DCHECK_EQ(key, dataset()->FormatName(cur_index_, i)); + TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i])); + TF_RETURN_IF_ERROR(reader_.status()); + } + cur_index_++; + return Status::OK(); } protected: @@ -728,240 +499,494 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_)); - if (cache_->IsClaimed()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kCurIndex), cur_index_)); + return Status::OK(); + } + + Status RestoreInternal( + IteratorContext* ctx, + IteratorStateReader* iterator_state_reader) override { + mutex_lock l(mu_); + { + // TODO(b/78048575): Update this when saving size_t tensors directly + // is supported. + int64 temp; TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("cache_claimed"), "")); - size_t cache_size = cache_->size(); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("cache_size"), cache_size)); - for (size_t i = 0; i < cache_size; i++) { - auto& element = cache_->at(i); - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("cache[", i, "].size")), - element.size())); - for (size_t j = 0; j < element.size(); ++j) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat("cache[", i, "][", j, "]")), - element[j])); - } - } - if (cache_->IsCompleted()) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("cache_completed"), "")); + iterator_state_reader->ReadScalar(full_name(kCurIndex), &temp)); + cur_index_ = static_cast(temp); + if (cur_index_ != temp) { + return errors::Internal("Invalid value for cur_index ", temp); } } - return SaveInput(writer, iterator_); + if (!reader_.Valid()) { + return errors::Internal("Error initializing BundleReader."); + } + reader_.Seek(dataset()->FormatName(cur_index_, 0)); + iterator_restored_ = true; + return Status::OK(); + } + + private: + mutex mu_; + size_t cur_index_ GUARDED_BY(mu_); + BundleReader reader_ GUARDED_BY(mu_); + bool iterator_restored_ GUARDED_BY(mu_); + }; // FileReaderIterator + + void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // We intentionally use the same prefix for both `FileReaderIterator` + // and `FileWriterIterator`. Since at any time there will be at most + // one of them alive, there should be no conflicts. This allows both + // iterators to use a common key for `cur_index`. We leverage this + // in the corner case when this iterator is restored from an old + // checkpoint in `write` mode and the cache has been completely + // flushed to disk since then. In that case we simply build a + // `FileReaderIterator` and seek to the `cur_index`. 
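The read path above is a small, fixed `BundleReader` traversal: check `status()`, advance past the header entry, then read key/value pairs until `Valid()` fails. A standalone sketch of that traversal with a hypothetical `DumpCache` helper (error handling trimmed to the calls the iterator actually makes):

```cpp
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
// Reads every cached tensor back out of a finished bundle, in key order.
Status DumpCache(Env* env, const string& prefix) {
  BundleReader reader(env, prefix);
  TF_RETURN_IF_ERROR(reader.status());
  // Advance before the first read, as the iterator does: the first table
  // entry is a header, not a cached tensor.
  for (reader.Next(); reader.Valid(); reader.Next()) {
    Tensor t;
    TF_RETURN_IF_ERROR(reader.ReadCurrent(&t));
    LOG(INFO) << reader.key() << " : " << t.DebugString();
  }
  return reader.status();
}
}  // namespace tensorflow
```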
+ switch (mode_) { + case Mode::read: + iterator_ = + absl::make_unique(FileReaderIterator::Params{ + dataset(), strings::StrCat(prefix(), kImpl)}); + break; + case Mode::write: + iterator_ = + absl::make_unique(FileWriterIterator::Params{ + dataset(), strings::StrCat(prefix(), kImpl)}); + } + } + + mutex mu_; + enum Mode { read, write }; + Mode mode_ GUARDED_BY(mu_); + std::unique_ptr iterator_ GUARDED_BY(mu_); + }; // FileIterator + + const DatasetBase* const input_; + const string filename_; + Env* const env_; + const size_t num_tensors_; + const size_t tensor_index_padding_size_; + static const size_t kMaxItems = 10000000; // 10 million + const size_t item_index_padding_size_; + const string tensor_format_string_; +}; // FileDataset + +class CacheDatasetOp::MemoryDataset : public DatasetBase { + public: + explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), input_(input) { + input->Ref(); + } + + ~MemoryDataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + name_utils::IteratorPrefixParams params; + params.dataset_prefix = kMemoryDatasetPrefix; + return absl::make_unique(MemoryIterator::Params{ + this, name_utils::IteratorPrefix(kDatasetType, prefix, params)}); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + name_utils::DatasetDebugStringParams params; + params.dataset_prefix = kMemoryDatasetPrefix; + return name_utils::DatasetDebugString(kDatasetType, params); + } + + int64 Cardinality() const override { return input_->Cardinality(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); + Node* filename_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_node, filename_node}, output)); + return Status::OK(); + } + + private: + // A thread-safe data structure for caching dataset elements. + // + // The expected use is that a single `MemoryWriterIterator` populates the + // cache with dataset elements. Once all elements are cached, the cache can + // be used by one or more `MemoryReaderIterator`s. + class MemoryCache : public ResourceBase { + public: + MemoryCache() = default; + + string DebugString() const override { return "CacheDataset::MemoryCache"; } + + // Marks the cache as completed. + void Complete() { + mutex_lock l(mu_); + completed_ = true; + } + + // Returns whether the cache is claimed. + bool IsClaimed() { + tf_shared_lock l(mu_); + return claimed_; + } + + // Returns whether the cache is completed. + bool IsCompleted() { + tf_shared_lock l(mu_); + return completed_; + } + + // Attempts to claim the cache, returning whether the cache was claimed. + bool MaybeClaim() { + mutex_lock l(mu_); + if (!claimed_) { + claimed_ = true; + return true; + } + return false; + } + + // Resets the cache. + void Reset() { + mutex_lock l(mu_); + claimed_ = false; + completed_ = false; + cache_.clear(); + } + + // Returns the element at the given index. + const std::vector& at(int64 index) { + tf_shared_lock l(mu_); + DCHECK(index < cache_.size()); + return cache_[index]; + } + + // Adds the element to the cache. 
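The `claimed_`/`completed_` flags above implement a single-writer, multi-reader handshake: exactly one `MemoryWriterIterator` wins `MaybeClaim()`, and reads are legal only after `Complete()`. A stripped-down, self-contained model of that handshake, with plain `std::mutex` and `int` payloads standing in for TF mutexes and tensor vectors:

```cpp
#include <cstddef>
#include <mutex>
#include <vector>

// Self-contained model of the MemoryCache handshake: the first caller of
// MaybeClaim() becomes the writer; readers touch data only after Complete()
// has published it.
class MiniCache {
 public:
  bool MaybeClaim() {
    std::lock_guard<std::mutex> l(mu_);
    if (claimed_) return false;
    claimed_ = true;  // This caller is now the sole writer.
    return true;
  }
  void Complete() {  // Writer publishes the finished cache.
    std::lock_guard<std::mutex> l(mu_);
    completed_ = true;
  }
  bool IsCompleted() {
    std::lock_guard<std::mutex> l(mu_);
    return completed_;
  }
  void Add(int element) {  // Writer-only while incomplete.
    std::lock_guard<std::mutex> l(mu_);
    data_.push_back(element);
  }
  std::size_t size() {
    std::lock_guard<std::mutex> l(mu_);
    return data_.size();
  }

 private:
  std::mutex mu_;
  bool claimed_ = false;
  bool completed_ = false;
  std::vector<int> data_;
};
```

The design point carried over from the hunk above: losing the claim race is not an error. It simply routes an iterator onto the read path, which then insists on `IsCompleted()` before serving elements.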
+ void emplace_back(std::vector element) { + mutex_lock l(mu_); + cache_.emplace_back(std::move(element)); + } + + // Returns the size of the cache. + size_t size() { + tf_shared_lock l(mu_); + return cache_.size(); + } + + private: + mutex mu_; + // Determines whether a writer has claimed the cache. + bool claimed_ GUARDED_BY(mu_) = false; + // Determines whether all elements of the dataset have been cached. + bool completed_ GUARDED_BY(mu_) = false; + std::vector> cache_ GUARDED_BY(mu_); + }; + + class MemoryIterator : public DatasetIterator { + public: + explicit MemoryIterator(const Params& params) + : DatasetIterator(params) {} + + ~MemoryIterator() override { cache_->Unref(); } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + // Use the resource manager in the iterator context to get / create + // a cache. + ResourceMgr* mgr = ctx->resource_mgr(); + const string name = strings::StrCat(prefix(), name_utils::kDelimiter, + dataset()->node_name(), + name_utils::kDelimiter, kMemoryCache); + TF_RETURN_IF_ERROR(mgr->LookupOrCreate( + kTFData, name, &cache_, [](MemoryCache** cache) { + *cache = new MemoryCache(); + return Status::OK(); + })); + mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read; + InitializeIterator(); + if (mode_ == Mode::read && !cache_->IsCompleted()) { + return errors::Internal( + "Cache should only be read after it has been completed."); + } + return iterator_->Initialize(ctx); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + return iterator_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_)); + if (cache_->IsClaimed()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheClaimed), "")); + size_t cache_size = cache_->size(); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kCacheSize), cache_size)); + for (size_t i = 0; i < cache_size; i++) { + auto& element = cache_->at(i); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)), + element.size())); + for (size_t j = 0; j < element.size(); ++j) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat(kCache, "[", i, "][", j, "]")), + element[j])); + } + } + if (cache_->IsCompleted()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kCacheCompleted), "")); + } + } + return SaveInput(writer, iterator_); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + iterator_.reset(); + cache_->Reset(); + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp)); + mode_ = static_cast(temp); + } + if (reader->Contains(full_name(kCacheClaimed))) { + CHECK(cache_->MaybeClaim()); + size_t cache_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp)); + cache_size = static_cast(temp); + } + for (size_t i = 0; i < cache_size; ++i) { + std::vector element; + size_t element_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)), + &temp)); + element_size = static_cast(temp); + } + element.reserve(element_size); + 
for (size_t j = 0; j < element_size; ++j) { + element.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat(kCache, "[", i, "][", j, "]")), + &element.back())); + } + cache_->emplace_back(std::move(element)); + } + if (reader->Contains(full_name(kCacheCompleted))) { + cache_->Complete(); + } + } + InitializeIterator(); + TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); + return RestoreInput(ctx, reader, iterator_); + } + + private: + class MemoryWriterIterator : public DatasetIterator { + public: + explicit MemoryWriterIterator(const Params& params, MemoryCache* cache) + : DatasetIterator(params), cache_(cache) { + CHECK(cache_); + } + + ~MemoryWriterIterator() override { + mutex_lock l(mu_); + if (cache_->size() > 0 && !cache_->IsCompleted()) { + LOG(WARNING) + << "The calling iterator did not fully read the dataset being " + "cached. In order to avoid unexpected truncation of the " + "dataset, the partially cached contents of the dataset " + "will be discarded. This can happen if you have an input " + "pipeline similar to `dataset.cache().take(k).repeat()`. " + "You should use `dataset.take(k).cache().repeat()` instead."; + cache_->Reset(); + } + } + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + cache_->Complete(); + return Status::OK(); + } + RecordBufferEnqueue(ctx, *out_tensors); + cache_->emplace_back(*out_tensors); + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + return SaveInput(writer, input_impl_); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - iterator_.reset(); - cache_->Reset(); - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp)); - mode_ = static_cast(temp); - } - if (reader->Contains(full_name("cache_claimed"))) { - CHECK(cache_->MaybeClaim()); - size_t cache_size; - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("cache_size"), &temp)); - cache_size = static_cast(temp); - } - for (size_t i = 0; i < cache_size; ++i) { - std::vector element; - size_t element_size; - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat("cache[", i, "].size")), &temp)); - element_size = static_cast(temp); - } - element.reserve(element_size); - for (size_t j = 0; j < element_size; ++j) { - element.emplace_back(); - TF_RETURN_IF_ERROR(reader->ReadTensor( - full_name(strings::StrCat("cache[", i, "][", j, "]")), - &element.back())); - } - cache_->emplace_back(std::move(element)); - } - if (reader->Contains(full_name("cache_completed"))) { - cache_->Complete(); - } - } - InitializeIterator(); - TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); - return RestoreInput(ctx, reader, iterator_); + return RestoreInput(ctx, reader, input_impl_); } private: - class MemoryWriterIterator : public DatasetIterator { - public: - explicit MemoryWriterIterator(const Params& params, MemoryCache* cache) - : DatasetIterator(params), cache_(cache) { - CHECK(cache_); - } + mutex mu_; + 
std::unique_ptr input_impl_ GUARDED_BY(mu_); + MemoryCache* const cache_ GUARDED_BY(mu_); // not owned. + }; // MemoryWriterIterator - ~MemoryWriterIterator() override { - mutex_lock l(mu_); - if (cache_->size() > 0 && !cache_->IsCompleted()) { - LOG(WARNING) - << "The calling iterator did not fully read the dataset being " - "cached. In order to avoid unexpected truncation of the " - "dataset, the partially cached contents of the dataset " - "will be discarded. This can happen if you have an input " - "pipeline similar to `dataset.cache().take(k).repeat()`. " - "You should use `dataset.take(k).cache().repeat()` instead."; - cache_->Reset(); - } - } + class MemoryReaderIterator : public DatasetIterator { + public: + explicit MemoryReaderIterator(const Params& params, MemoryCache* cache) + : DatasetIterator(params), cache_(cache), index_(0) { + CHECK(cache); + } - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + Status Initialize(IteratorContext* ctx) override { + // The memory allocated for the cache is owned by the parent + // dataset but performance modeling uses the iterator abstraction and + // thus we record the memory allocated for the cache here. The caveat + // is that this is incorrect if there are concurrent instances of this + // iterator. + tf_shared_lock l(mu_); + for (size_t i = 0; i < cache_->size(); ++i) { + RecordBufferEnqueue(ctx, cache_->at(i)); } + return Status::OK(); + } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (*end_of_sequence) { - cache_->Complete(); - return Status::OK(); - } - RecordBufferEnqueue(ctx, *out_tensors); - cache_->emplace_back(*out_tensors); + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ < cache_->size()) { + const std::vector& cache_tensors = cache_->at(index_); + out_tensors->insert(out_tensors->begin(), cache_tensors.begin(), + cache_tensors.end()); + index_++; + *end_of_sequence = false; return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - return SaveInput(writer, input_impl_); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - return RestoreInput(ctx, reader, input_impl_); - } - - private: - mutex mu_; - std::unique_ptr input_impl_ GUARDED_BY(mu_); - MemoryCache* const cache_ GUARDED_BY(mu_); // not owned. - }; // MemoryWriterIterator - - class MemoryReaderIterator : public DatasetIterator { - public: - explicit MemoryReaderIterator(const Params& params, MemoryCache* cache) - : DatasetIterator(params), cache_(cache), index_(0) { - CHECK(cache); - } - - Status Initialize(IteratorContext* ctx) override { - // The memory allocated for the cache is owned by the parent - // dataset but performance modeling uses the iterator abstraction and - // thus we record the memory allocated for the cache here. The caveat - // is that this is incorrect if there are concurrent instances of this - // iterator. 
- tf_shared_lock l(mu_); - for (size_t i = 0; i < cache_->size(); ++i) { - RecordBufferEnqueue(ctx, cache_->at(i)); - } + } else { + *end_of_sequence = true; return Status::OK(); } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - if (index_ < cache_->size()) { - const std::vector& cache_tensors = cache_->at(index_); - out_tensors->insert(out_tensors->begin(), cache_tensors.begin(), - cache_tensors.end()); - index_++; - *end_of_sequence = false; - return Status::OK(); - } else { - *end_of_sequence = true; - return Status::OK(); - } - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp)); - index_ = static_cast(temp); - } - return Status::OK(); - } - - private: - mutex mu_; - MemoryCache* const cache_ GUARDED_BY(mu_); // not owned. - size_t index_ GUARDED_BY(mu_); - }; // MemoryReaderIterator - - void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - switch (mode_) { - case Mode::read: - iterator_ = absl::make_unique( - MemoryReaderIterator::Params{dataset(), - strings::StrCat(prefix(), "Impl")}, - cache_); - break; - case Mode::write: - iterator_ = absl::make_unique( - MemoryWriterIterator::Params{dataset(), - strings::StrCat(prefix(), "Impl")}, - cache_); - } } + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &temp)); + index_ = static_cast(temp); + } + return Status::OK(); + } + + private: mutex mu_; - MemoryCache* cache_ GUARDED_BY(mu_); // not owned. - enum Mode { read, write }; - Mode mode_ GUARDED_BY(mu_); - std::unique_ptr iterator_ GUARDED_BY(mu_); - }; // MemoryIterator + MemoryCache* const cache_ GUARDED_BY(mu_); // not owned. + size_t index_ GUARDED_BY(mu_); + }; // MemoryReaderIterator - const DatasetBase* const input_; - }; // MemoryDataset -}; // CacheDatasetOp + void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + switch (mode_) { + case Mode::read: + iterator_ = absl::make_unique( + MemoryReaderIterator::Params{dataset(), + strings::StrCat(prefix(), kImpl)}, + cache_); + break; + case Mode::write: + iterator_ = absl::make_unique( + MemoryWriterIterator::Params{dataset(), + strings::StrCat(prefix(), kImpl)}, + cache_); + } + } + mutex mu_; + MemoryCache* cache_ GUARDED_BY(mu_); // not owned. 
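A recurring pattern in the `SaveInternal`/`RestoreInternal` methods above (the `TODO(b/78048575)` comments elsewhere in the file point at the same limitation): `size_t` state is checkpointed through `int64` scalars, so every restore re-checks that the round trip was lossless. Isolated, the pattern looks like this hypothetical helper:

```cpp
#include <cstddef>
#include <cstdint>

// Restores a size_t that was checkpointed as int64, rejecting any value
// that does not survive the round trip (e.g. one too large for a 32-bit
// size_t). Mirrors the cast-and-compare checks in RestoreInternal above.
bool RestoreSizeT(int64_t saved, std::size_t* out) {
  std::size_t value = static_cast<std::size_t>(saved);
  if (static_cast<int64_t>(value) != saved) {
    return false;  // The real code returns errors::Internal here.
  }
  *out = value;
  return true;
}
```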
+    enum Mode { read, write };
+    Mode mode_ GUARDED_BY(mu_);
+    std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
+  };  // MemoryIterator
+
+  const DatasetBase* const input_;
+};  // MemoryDataset
+
+CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
+    : UnaryDatasetOpKernel(ctx) {}
+
+void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                                 DatasetBase** output) {
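+  // An empty `filename` selects the in-memory cache (`dataset.cache()` in
+  // Python); a non-empty `filename` selects the file-backed cache
+  // (`dataset.cache(filename)`).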
+  // Parse out the filenames tensor.
+  string filename;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, kFileName, &filename));
+
+  if (filename.empty()) {
+    *output = new MemoryDataset(ctx, input);
+  } else {
+    *output = new FileDataset(ctx, input, filename, ctx->env());
+  }
+}
+
+namespace {
 REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                         CacheDatasetOp);
-
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.h b/tensorflow/core/kernels/data/cache_dataset_ops.h
new file mode 100644
index 00000000000..af023a60075
--- /dev/null
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.h
@@ -0,0 +1,45 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+namespace data {
+
+class CacheDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  static constexpr const char* const kDatasetType = "Cache";
+  static constexpr const char* const kInputDataset = "input_dataset";
+  static constexpr const char* const kFileName = "filename";
+  static constexpr const char* const kOutputTypes = "output_types";
+  static constexpr const char* const kOutputShapes = "output_shapes";
+
+  explicit CacheDatasetOp(OpKernelConstruction* ctx);
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override;
+
+ private:
+  class FileDataset;
+  class MemoryDataset;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc
new file mode 100644
index 00000000000..812d719946f
--- /dev/null
+++ b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc
@@ -0,0 +1,533 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/cache_dataset_ops.h"
+
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+constexpr char kNodeName[] = "cache_dataset";
+constexpr char kIteratorPrefix[] = "Iterator";
+constexpr char kFileDatasetPrefix[] = "File";
+constexpr char kMemoryDatasetPrefix[] = "Memory";
+
+class CacheDatasetOpTest : public DatasetOpsTestBase {
+ protected:
+  // Creates a `TensorSliceDataset` variant tensor from the input vector of
+  // tensors.
+  Status CreateTensorSliceDatasetTensor(
+      std::vector<Tensor>* const tensor_vector, Tensor* dataset_tensor) {
+    DatasetBase* tensor_slice_dataset;
+    TF_RETURN_IF_ERROR(CreateTensorSliceDataset(
+        "tensor_slice_node", tensor_vector, &tensor_slice_dataset));
+    TF_RETURN_IF_ERROR(
+        StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor));
+    return Status::OK();
+  }
+
+  // Creates a new `CacheDataset` op kernel.
+  Status CreateCacheDatasetOpKernel(
+      const DataTypeVector& output_types,
+      const std::vector<PartialTensorShape>& output_shapes,
+      std::unique_ptr<OpKernel>* cache_dataset_op_kernel) {
+    NodeDef node_def = test::function::NDef(
+        kNodeName, name_utils::OpName(CacheDatasetOp::kDatasetType),
+        {CacheDatasetOp::kInputDataset, CacheDatasetOp::kFileName},
+        {{CacheDatasetOp::kOutputTypes, output_types},
+         {CacheDatasetOp::kOutputShapes, output_shapes}});
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, cache_dataset_op_kernel));
+    return Status::OK();
+  }
+
+  // Creates a new `CacheDataset` op kernel context.
+  Status CreateCacheDatasetContext(
+      OpKernel* const op_kernel,
+      gtl::InlinedVector<TensorValue, 4>* const inputs,
+      std::unique_ptr<OpKernelContext>* context) {
+    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
+    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    return Status::OK();
+  }
+};
+
+struct TestCase {
+  std::vector<Tensor> input_tensors;
+  string file_name;
+  std::vector<Tensor> expected_outputs;
+  DataTypeVector expected_output_dtypes;
+  std::vector<PartialTensorShape> expected_output_shapes;
+  int64 expected_cardinality;
+  std::vector<int> breakpoints;
+};
+
+// Test case 1: cache data in file.
+TestCase TestCase1() {
+  return {
+      /*input_tensors*/ {DatasetOpsTestBase::CreateTensor<int64>(
+          TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*file_name*/ absl::StrCat(testing::TmpDir(), "/cache_data"),
+      /*expected_outputs*/
+      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {0, 1, 2}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {3, 4, 5}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
+      /*expected_cardinality*/ 3,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+// Test case 2: cache empty data in file.
+TestCase TestCase2() {
+  return {/*input_tensors*/ {
+              DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})},
+          /*file_name*/ absl::StrCat(testing::TmpDir(), "/empty_cache_data"),
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 4, 11}};
+}
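+// The `breakpoints` above are element counts at which the Roundtrip test
+// below saves and restores the iterator; a breakpoint beyond the expected
+// cardinality exercises the end-of-sequence state.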
+// Test case 3: cache data in memory.
+TestCase TestCase3() {
+  return {
+      /*input_tensors*/ {DatasetOpsTestBase::CreateTensor<int64>(
+          TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})},
+      /*file_name*/ "",
+      /*expected_outputs*/
+      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {0, 1, 2}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {3, 4, 5}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{3, 1}, {6, 7, 8})},
+      /*expected_output_dtypes*/ {DT_INT64},
+      /*expected_output_shapes*/ {PartialTensorShape({3, 1})},
+      /*expected_cardinality*/ 3,
+      /*breakpoints*/ {0, 4, 11}};
+}
+
+// Test case 4: cache empty data in memory.
+TestCase TestCase4() {
+  return {/*input_tensors*/ {
+              DatasetOpsTestBase::CreateTensor<int64>(TensorShape{0}, {})},
+          /*file_name*/ "",
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 4, 11}};
+}
+
+class ParameterizedCacheDatasetOpTest
+    : public CacheDatasetOpTest,
+      public ::testing::WithParamInterface<TestCase> {};
+
+TEST_P(ParameterizedCacheDatasetOpTest, GetNext) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(
+      CreateIteratorContext(cache_dataset_context.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(cache_dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix,
+                                           &iterator));
+
+  // Test the write mode.
+  bool end_of_sequence = false;
+  std::vector<Tensor> out_tensors;
+  while (!end_of_sequence) {
+    std::vector<Tensor> next;
+    TF_EXPECT_OK(
+        iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
+    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
+  }
+  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
+                           /*compare_order*/ true));
+
+  // Test the read mode.
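+  // Recreating the iterator from the same dataset reuses the now-complete
+  // cache, so this second pass should be served from the cache rather than
+  // from the input dataset.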
+  TF_ASSERT_OK(cache_dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix,
+                                           &iterator));
+  end_of_sequence = false;
+  out_tensors.clear();
+  while (!end_of_sequence) {
+    std::vector<Tensor> next;
+    TF_EXPECT_OK(
+        iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
+    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
+  }
+  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
+                           /*compare_order*/ true));
+}
+
+TEST_F(CacheDatasetOpTest, DatasetNodeName) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = TestCase1();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  EXPECT_EQ(cache_dataset->node_name(), kNodeName);
+}
+
+TEST_P(ParameterizedCacheDatasetOpTest, DatasetTypeString) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  EXPECT_EQ(cache_dataset->type_string(),
+            name_utils::OpName(CacheDatasetOp::kDatasetType));
+}
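+// `name_utils::OpName(CacheDatasetOp::kDatasetType)` resolves to
+// "CacheDataset", matching the name registered by REGISTER_KERNEL_BUILDER, so
+// the check above avoids hard-coding the op name.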
+TEST_P(ParameterizedCacheDatasetOpTest, DatasetOutputDtypes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  TF_EXPECT_OK(VerifyTypesMatch(cache_dataset->output_dtypes(),
+                                test_case.expected_output_dtypes));
+}
+
+TEST_P(ParameterizedCacheDatasetOpTest, DatasetOutputShapes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  TF_EXPECT_OK(VerifyShapesCompatible(cache_dataset->output_shapes(),
+                                      test_case.expected_output_shapes));
+}
+
+TEST_P(ParameterizedCacheDatasetOpTest, Cardinality) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  EXPECT_EQ(cache_dataset->Cardinality(), test_case.expected_cardinality);
+}
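+// Caching does not change the number of elements, so the expected
+// cardinalities above (3 and 0) are exactly those of the input
+// `TensorSliceDataset`.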
+TEST_P(ParameterizedCacheDatasetOpTest, DatasetSave) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  std::unique_ptr<SerializationContext> serialization_context;
+  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
+  VariantTensorData data;
+  VariantTensorDataWriter writer(&data);
+  TF_ASSERT_OK(cache_dataset->Save(serialization_context.get(), &writer));
+  TF_ASSERT_OK(writer.Flush());
+}
+
+TEST_P(ParameterizedCacheDatasetOpTest, IteratorOutputShapes) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(
+      CreateIteratorContext(cache_dataset_context.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(cache_dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix,
+                                           &iterator));
+
+  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
+                                      test_case.expected_output_shapes));
+}
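+// With the new `dataset_prefix` field (see the name_utils change at the end
+// of this diff), the expected prefix embeds the concrete variant, e.g.
+// "Iterator::FileCache" for the file-backed case and "Iterator::MemoryCache"
+// for the in-memory case.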
+TEST_P(ParameterizedCacheDatasetOpTest, IteratorOutputPrefix) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(
+      CreateIteratorContext(cache_dataset_context.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(cache_dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix,
+                                           &iterator));
+
+  name_utils::IteratorPrefixParams params;
+  params.dataset_prefix =
+      test_case.file_name.empty() ? kMemoryDatasetPrefix : kFileDatasetPrefix;
+  EXPECT_EQ(iterator->prefix(),
+            name_utils::IteratorPrefix(CacheDatasetOp::kDatasetType,
+                                       kIteratorPrefix, params));
+}
+
+TEST_P(ParameterizedCacheDatasetOpTest, Roundtrip) {
+  int thread_num = 2, cpu_num = 2;
+  TestCase test_case = GetParam();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  std::unique_ptr<OpKernel> cache_dataset_kernel;
+  TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
+                                          test_case.expected_output_shapes,
+                                          &cache_dataset_kernel));
+  Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
+  std::vector<Tensor> inputs_for_tensor_slice_dataset =
+      test_case.input_tensors;
+  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
+                                              &tensor_slice_dataset_tensor));
+  Tensor file_name =
+      CreateTensor<string>(TensorShape{}, {test_case.file_name});
+  gtl::InlinedVector<TensorValue, 4> inputs(
+      {TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
+  std::unique_ptr<OpKernelContext> cache_dataset_context;
+  TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
+                                         &cache_dataset_context));
+  DatasetBase* cache_dataset;
+  TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
+                             cache_dataset_context.get(), &cache_dataset));
+  core::ScopedUnref scoped_unref(cache_dataset);
+
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(
+      CreateIteratorContext(cache_dataset_context.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(cache_dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix,
+                                           &iterator));
+
+  bool end_of_sequence = false;
+  std::vector<Tensor> out_tensors;
+  // For MemoryIterator in the read mode, the cache needs to be completed
+  // before it can be read.
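+  // (A fully consumed write pass marks the cache complete; an iterator that
+  // stopped early would cause ~MemoryWriterIterator to discard the partial
+  // cache.)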
+  if (test_case.file_name.empty()) {
+    while (!end_of_sequence) {
+      TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
+                                     &end_of_sequence));
+    }
+    end_of_sequence = false;
+    out_tensors.clear();
+    TF_ASSERT_OK(cache_dataset->MakeIterator(iterator_ctx.get(),
+                                             kIteratorPrefix, &iterator));
+  }
+
+  std::unique_ptr<SerializationContext> serialization_ctx;
+  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
+  int cur_iteration = 0;
+  auto expected_outputs_it = test_case.expected_outputs.begin();
+  for (int breakpoint : test_case.breakpoints) {
+    VariantTensorData data;
+    VariantTensorDataWriter writer(&data);
+    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
+    TF_EXPECT_OK(writer.Flush());
+    VariantTensorDataReader reader(&data);
+    TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, kIteratorPrefix,
+                                 *cache_dataset, &iterator));
+
+    while (cur_iteration <= breakpoint) {
+      out_tensors.clear();
+      TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
+                                     &end_of_sequence));
+      if (!end_of_sequence) {
+        EXPECT_LT(expected_outputs_it, test_case.expected_outputs.end());
+        TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
+        expected_outputs_it++;
+      }
+      cur_iteration++;
+    }
+
+    if (breakpoint >= test_case.expected_cardinality) {
+      EXPECT_TRUE(end_of_sequence);
+      EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
+    } else {
+      EXPECT_FALSE(end_of_sequence);
+    }
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    CacheDatasetOpTest, ParameterizedCacheDatasetOpTest,
+    ::testing::ValuesIn(std::vector<TestCase>({TestCase1(), TestCase2(),
+                                               TestCase3(), TestCase4()})));
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/name_utils.cc b/tensorflow/core/kernels/data/name_utils.cc
index 391f45014c8..b6404892fdb 100644
--- a/tensorflow/core/kernels/data/name_utils.cc
+++ b/tensorflow/core/kernels/data/name_utils.cc
@@ -65,10 +65,11 @@ string IteratorPrefix(const string& dataset_type, const string& prefix) {
 string IteratorPrefix(const string& dataset_type, const string& prefix,
                       const IteratorPrefixParams& params) {
   if (params.op_version == 1) {
-    return strings::StrCat(prefix, kDelimiter, dataset_type);
+    return strings::StrCat(prefix, kDelimiter, params.dataset_prefix,
+                           dataset_type);
   }
-  return strings::StrCat(prefix, kDelimiter, dataset_type, kVersion,
-                         params.op_version);
+  return strings::StrCat(prefix, kDelimiter, params.dataset_prefix,
+                         dataset_type, kVersion, params.op_version);
 }
 
 }  // namespace name_utils
diff --git a/tensorflow/core/kernels/data/name_utils.h b/tensorflow/core/kernels/data/name_utils.h
index 0efa825ec5e..5171b8e05e3 100644
--- a/tensorflow/core/kernels/data/name_utils.h
+++ b/tensorflow/core/kernels/data/name_utils.h
@@ -44,6 +44,7 @@ struct DatasetDebugStringParams {
 
 struct IteratorPrefixParams {
   int op_version = 1;
+  string dataset_prefix = "";
 };
 
 // Merge the given args in the format of "(arg1, arg2, ..., argn)".