diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index db9984e02f8..2cc602a15ae 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -549,7 +549,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       if (!background_threads_started_) {
         for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
           ++num_active_threads_;
-          thread_pool_->Schedule([this, i]() { ReadingFilesLoop(i); });
+          thread_pool_->Schedule(
+              [this, i, env = ctx->env()]() { ReadingFilesLoop(env, i); });
         }
         background_threads_started_ = true;
       }
@@ -731,13 +732,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
    private:
    // Reads one file end to end.
-    Status ReadFile(const string& filename) {
-      std::unique_ptr<RandomAccessFile> file;
-      TF_RETURN_IF_ERROR(
-          Env::Default()->NewRandomAccessFile(filename, &file));
-      snapshot_util::Reader reader(file.get(), dataset()->compression_,
-                                   version_, dataset()->output_dtypes());
-
+    Status ReadFile(Env* env, const string& filename) {
+      std::unique_ptr<snapshot_util::Reader> reader;
+      TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
+          env, filename, dataset()->compression_, version_,
+          dataset()->output_dtypes(), &reader));
       while (true) {
         // Wait for a slot in the buffer.
         {
@@ -754,7 +753,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           }
         }
         std::vector<Tensor> read_tensors;
-        Status s = reader.ReadTensors(&read_tensors);
+        Status s = reader->ReadTensors(&read_tensors);
         if (s.ok()) {
           profiler::TraceMe activity(
               [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
@@ -787,7 +786,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
     // Pulls one file off the filenames_ list and reads it through. When
     // all files are read, terminates.
-    void ReadingFilesLoop(int i) {
+    void ReadingFilesLoop(Env* env, int i) {
       auto cleanup = gtl::MakeCleanup([this]() {
         mutex_lock l(mu_);
         --num_active_threads_;
@@ -803,7 +802,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           }
           VLOG(2) << "Starting to read: " << filename;
         }
-        Status s = ReadFile(filename);
+        Status s = ReadFile(env, filename);
         // If we get to the end of the file, it's a clean termination and
         // we are at the end of the file. If all files have been processed,
         // then we insert an end_of_sequence marker in the buffer and
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
index ba8336653f4..6ff4476f9be 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
@@ -209,13 +209,27 @@ Status Writer::WriteRecord(const absl::Cord& data) {
 }
 #endif  // PLATFORM_GOOGLE
 
-Reader::Reader(RandomAccessFile* file, const string& compression_type,
+Status Reader::Create(Env* env, const std::string& filename,
+                      const string& compression_type, int version,
+                      const DataTypeVector& dtypes,
+                      std::unique_ptr<Reader>* out_reader) {
+  *out_reader =
+      absl::WrapUnique(new Reader(filename, compression_type, version, dtypes));
+
+  return (*out_reader)->Initialize(env);
+}
+
+Reader::Reader(const std::string& filename, const string& compression_type,
                int version, const DataTypeVector& dtypes)
-    : file_(file),
-      input_stream_(new io::RandomAccessInputStream(file)),
+    : filename_(filename),
       compression_type_(compression_type),
       version_(version),
-      dtypes_(dtypes) {
+      dtypes_(dtypes) {}
+
+Status Reader::Initialize(Env* env) {
+  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
+  input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
+
 #if defined(IS_SLIM_BUILD)
   if (compression_type_ != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@@ -232,16 +246,16 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
   } else if (compression_type_ == io::compression::kSnappy) {
     if (version_ == 0) {
       input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
-          file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
+          file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
           /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
     } else {
       input_stream_ =
-          absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
+          absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
     }
   }
 #endif  // IS_SLIM_BUILD
-  simple_tensor_mask_.reserve(dtypes.size());
-  for (const auto& dtype : dtypes) {
+  simple_tensor_mask_.reserve(dtypes_.size());
+  for (const auto& dtype : dtypes_) {
     if (DataTypeCanUseMemcpy(dtype)) {
       simple_tensor_mask_.push_back(true);
       num_simple_++;
@@ -250,6 +264,8 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
       num_complex_++;
     }
   }
+
+  return Status::OK();
 }
 
 Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h
index e1c6dbeb67b..3816525775b 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.h
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h
@@ -114,12 +114,19 @@ class Reader {
   static constexpr const char* const kReadCord = "ReadCord";
   static constexpr const char* const kSeparator = "::";
 
-  explicit Reader(RandomAccessFile* file, const string& compression_type,
-                  int version, const DataTypeVector& dtypes);
+  static Status Create(Env* env, const std::string& filename,
+                       const string& compression_type, int version,
+                       const DataTypeVector& dtypes,
+                       std::unique_ptr<Reader>* out_reader);
 
   Status ReadTensors(std::vector<Tensor>* read_tensors);
 
  private:
+  explicit Reader(const std::string& filename, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);
+
+  Status Initialize(Env* env);
+
   Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
 
   Status SnappyUncompress(
@@ -134,7 +141,8 @@ class Reader {
   Status ReadRecord(absl::Cord* record);
 #endif
 
-  RandomAccessFile* file_;
+  std::string filename_;
+  std::unique_ptr<RandomAccessFile> file_;
   std::unique_ptr<io::InputStreamInterface> input_stream_;
   const string compression_type_;
   const int version_;
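
The snapshot_dataset_op.cc hunks change where the reader threads get their Env: instead of ReadFile() reaching for the process-global Env::Default() when it opens a file, the scheduling site captures ctx->env() by value into the lambda and threads it through ReadingFilesLoop() into ReadFile(), so all reads go through the environment attached to the iterator context. A minimal standalone sketch of that capture pattern, using std::thread in place of TensorFlow's thread pool; Env, GlobalDefaultEnv, and ReadingFilesLoop below are illustrative stand-ins, not the TensorFlow definitions:

#include <cstdio>
#include <thread>
#include <vector>

// Stand-in for tensorflow::Env: a handle to per-context services.
struct Env {
  const char* name;
};

// Stand-in for Env::Default(): the process-wide fallback that the
// workers should no longer consult directly.
Env* GlobalDefaultEnv() {
  static Env e{"default"};
  return &e;
}

void ReadingFilesLoop(Env* env, int i) {
  std::printf("worker %d reads through env '%s'\n", i, env->name);
}

int main() {
  Env custom{"custom"};
  Env* ctx_env = &custom;  // stands in for ctx->env()

  std::vector<std::thread> workers;
  for (int i = 0; i < 2; ++i) {
    // Capture the context's Env by value at scheduling time, as the diff
    // does, rather than calling GlobalDefaultEnv() inside the worker.
    workers.emplace_back([i, env = ctx_env]() { ReadingFilesLoop(env, i); });
  }
  for (auto& t : workers) t.join();
  return 0;
}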
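
The snapshot_util changes are a two-phase construction refactor: the Reader constructor now only records its parameters, and the fallible work (opening the file, wiring up the input stream) moves into a private Initialize() that the static Create() factory runs before handing the object back, so a failed open surfaces as a Status instead of a half-initialized Reader. A self-contained sketch of the same pattern; Status and FileReader here are simplified stand-ins, not the TensorFlow types:

#include <cstdio>
#include <memory>
#include <string>
#include <utility>

// Stand-in for tensorflow::Status: either ok() or an error message.
class Status {
 public:
  static Status OK() { return Status(""); }
  static Status Error(std::string msg) { return Status(std::move(msg)); }
  bool ok() const { return msg_.empty(); }
  const std::string& message() const { return msg_; }

 private:
  explicit Status(std::string msg) : msg_(std::move(msg)) {}
  std::string msg_;
};

// Two-phase construction: the constructor only stores parameters; all
// fallible work happens in Initialize(), which Create() runs before
// returning the object to the caller.
class FileReader {
 public:
  static Status Create(const std::string& filename,
                       std::unique_ptr<FileReader>* out_reader) {
    // The constructor is private, so make_unique cannot be used here.
    out_reader->reset(new FileReader(filename));
    return (*out_reader)->Initialize();
  }

  ~FileReader() {
    if (file_ != nullptr) std::fclose(file_);
  }

 private:
  explicit FileReader(std::string filename)
      : filename_(std::move(filename)) {}

  Status Initialize() {
    file_ = std::fopen(filename_.c_str(), "rb");
    if (file_ == nullptr) {
      return Status::Error("could not open " + filename_);
    }
    return Status::OK();
  }

  std::string filename_;
  std::FILE* file_ = nullptr;
};

int main() {
  std::unique_ptr<FileReader> reader;
  Status s = FileReader::Create("/tmp/example.snapshot", &reader);
  if (!s.ok()) {
    std::fprintf(stderr, "Create failed: %s\n", s.message().c_str());
    return 1;
  }
  return 0;
}

Keeping the constructor private forces every caller through Create(); that is also why the diff builds the object with absl::WrapUnique(new Reader(...)) rather than absl::make_unique, which cannot reach a private constructor.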