diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
index 6c4d6424146..877d05ebb3f 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
@@ -62,7 +62,7 @@ Status Writer::Create(Env* env, const std::string& filename,
 }
 
 Status Writer::Initialize(tensorflow::Env* env) {
-  TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
+  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
 #if defined(IS_SLIM_BUILD)
   if (compression_type_ != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@@ -228,13 +228,14 @@ class Reader::Dataset : public DatasetBase {
   explicit Dataset(const std::string& filename, const std::string& compression,
                    const int64 version, const DataTypeVector& dtypes,
                    const std::vector<PartialTensorShape>& shapes,
-                   DatasetContext::Params params)
+                   const int64 start_index, DatasetContext::Params params)
       : DatasetBase(DatasetContext(std::move(params))),
         filename_(filename),
         compression_(compression),
         version_(version),
         dtypes_(dtypes),
-        shapes_(shapes) {}
+        shapes_(shapes),
+        start_index_(start_index) {}
 
   const DataTypeVector& output_dtypes() const override { return dtypes_; }
 
@@ -268,6 +269,7 @@ class Reader::Dataset : public DatasetBase {
   int64 version_;
   DataTypeVector dtypes_;
   std::vector<PartialTensorShape> shapes_;
+  const int64 start_index_;
 
   class Iterator : public DatasetIterator<Dataset> {
    public:
@@ -275,9 +277,10 @@
         : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return Reader::Create(ctx->env(), dataset()->filename_,
-                            dataset()->compression_, dataset()->version_,
-                            dataset()->dtypes_, &reader_);
+      TF_RETURN_IF_ERROR(Reader::Create(
+          ctx->env(), dataset()->filename_, dataset()->compression_,
+          dataset()->version_, dataset()->dtypes_, &reader_));
+      return reader_->SkipRecords(dataset()->start_index_);
     }
 
    protected:
@@ -397,17 +400,32 @@ Status Reader::MakeNestedDataset(Env* env,
                                  const string& compression_type, int version,
                                  const DataTypeVector& dtypes,
                                  const std::vector<PartialTensorShape>& shapes,
+                                 const int64 start_index,
                                  DatasetBase** output) {
   std::vector<Dataset*> datasets;
 
   datasets.reserve(filenames.size());
   for (const auto& filename : filenames) {
+    // TODO(frankchn): The reading pattern could be controlled in a non-round
+    // robin fashion, so we cannot assume a round-robin manner when restoring.
+    int64 dataset_start_index = start_index / filenames.size();
+    if (start_index % filenames.size() > datasets.size()) {
+      dataset_start_index++;
+    }
+
     datasets.push_back(
         new Dataset(filename, compression_type, version, dtypes, shapes,
+                    dataset_start_index,
                     DatasetContext::Params({"snapshot_util::Reader::Dataset",
                                             "snapshot_util_reader_Dataset"})));
   }
 
+  // Rotate the vector such that the first dataset contains the next element
+  // to be produced.
+  std::rotate(datasets.begin(),
+              datasets.begin() + (start_index % filenames.size()),
+              datasets.end());
+
   *output = new NestedDataset(
       datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
                                         "snapshot_util_reader_NestedDataset"}));
@@ -463,6 +481,15 @@ Status Reader::Initialize(Env* env) {
   return Status::OK();
 }
 
+Status Reader::SkipRecords(int64 num_records) {
+  // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
+  for (int i = 0; i < num_records; ++i) {
+    std::vector<Tensor> unused_tensors;
+    TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
+  }
+  return Status::OK();
+}
+
 Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
   profiler::TraceMe activity(
       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h
index dd15c591a22..79299bb79b4 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.h
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h
@@ -130,10 +130,13 @@ class Reader {
                                   const string& compression_type, int version,
                                   const DataTypeVector& dtypes,
                                   const std::vector<PartialTensorShape>& shapes,
+                                  const int64 start_index,
                                   DatasetBase** output);
 
   Status ReadTensors(std::vector<Tensor>* read_tensors);
 
+  Status SkipRecords(int64 num_records);
+
  private:
   explicit Reader(const std::string& filename, const string& compression_type,
                   int version, const DataTypeVector& dtypes);
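Note on the resume arithmetic in Reader::MakeNestedDataset above: with N shard files read round-robin, global element k lives in shard k % N at per-shard index k / N. So when resuming at start_index, every shard skips start_index / N records, shards that come before start_index % N in the rotation skip one more, and the vector is rotated so the shard holding the next element is read first. Below is a minimal standalone sketch of that arithmetic, assuming strictly round-robin writes; the names ShardPlan and PlanResume are illustrative only and are not part of the TensorFlow sources.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch (not TensorFlow code) of the per-shard skip counts and rotation
// point computed in Reader::MakeNestedDataset when restoring a snapshot
// reader at global element `start_index`.
struct ShardPlan {
  std::vector<int64_t> skip_per_shard;  // records each shard must skip
  int64_t first_shard;                  // shard that yields the next element
};

ShardPlan PlanResume(int64_t num_shards, int64_t start_index) {
  ShardPlan plan;
  plan.skip_per_shard.resize(num_shards);
  for (int64_t i = 0; i < num_shards; ++i) {
    // Every shard has already produced start_index / num_shards full rounds;
    // shards earlier in the rotation order have produced one extra element.
    // This mirrors `start_index % filenames.size() > datasets.size()` in the
    // diff, where datasets.size() is the loop index i.
    plan.skip_per_shard[i] =
        start_index / num_shards + (i < start_index % num_shards ? 1 : 0);
  }
  // Reading resumes at the shard holding element start_index, which is what
  // the std::rotate call in the diff achieves.
  plan.first_shard = start_index % num_shards;
  return plan;
}

int main() {
  // 3 shards, resuming at global element 7: elements were written in shard
  // order 0,1,2,0,1,2,0,1 so element 7 lives in shard 1 at per-shard index 2.
  ShardPlan plan = PlanResume(/*num_shards=*/3, /*start_index=*/7);
  for (int64_t i = 0; i < 3; ++i)
    std::cout << "shard " << i << " skips " << plan.skip_per_shard[i] << "\n";
  std::cout << "resume reading from shard " << plan.first_shard << "\n";
  // Prints: shard 0 skips 3, shard 1 skips 2, shard 2 skips 2; resume from 1.
}

As the TODO in the diff itself notes, this restore logic assumes the reader consumed shards in round-robin order; a different read pattern would invalidate the per-shard skip counts.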