[tf.data] Refactor snapshot_util::Reader to own a RandomAccessFile

PiperOrigin-RevId: 306977814
Change-Id: I78140f6e01ac457ffd45c2c083e6d87ed60cccec
This commit is contained in:
Frank Chen 2020-04-16 21:03:53 -07:00 committed by TensorFlower Gardener
parent b5f7f775aa
commit 9b24e4fa8b
3 changed files with 45 additions and 22 deletions

View File

@@ -549,7 +549,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (!background_threads_started_) { if (!background_threads_started_) {
for (int i = 0; i < dataset()->num_reader_threads_; ++i) { for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
++num_active_threads_; ++num_active_threads_;
thread_pool_->Schedule([this, i]() { ReadingFilesLoop(i); }); thread_pool_->Schedule(
[this, i, env = ctx->env()]() { ReadingFilesLoop(env, i); });
} }
background_threads_started_ = true; background_threads_started_ = true;
} }
@@ -731,13 +732,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
private: private:
// Reads one file end to end. // Reads one file end to end.
Status ReadFile(const string& filename) { Status ReadFile(Env* env, const string& filename) {
std::unique_ptr<RandomAccessFile> file; std::unique_ptr<snapshot_util::Reader> reader;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
Env::Default()->NewRandomAccessFile(filename, &file)); Env::Default(), filename, dataset()->compression_, version_,
snapshot_util::Reader reader(file.get(), dataset()->compression_, dataset()->output_dtypes(), &reader));
version_, dataset()->output_dtypes());
while (true) { while (true) {
// Wait for a slot in the buffer. // Wait for a slot in the buffer.
{ {
@@ -754,7 +753,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
} }
std::vector<Tensor> read_tensors; std::vector<Tensor> read_tensors;
Status s = reader.ReadTensors(&read_tensors); Status s = reader->ReadTensors(&read_tensors);
if (s.ok()) { if (s.ok()) {
profiler::TraceMe activity( profiler::TraceMe activity(
[&]() { return absl::StrCat(prefix(), kSeparator, kParse); }, [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
@@ -787,7 +786,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
// Pulls one file off the filenames_ list and reads it through. When // Pulls one file off the filenames_ list and reads it through. When
// all files are read, terminates. // all files are read, terminates.
void ReadingFilesLoop(int i) { void ReadingFilesLoop(Env* env, int i) {
auto cleanup = gtl::MakeCleanup([this]() { auto cleanup = gtl::MakeCleanup([this]() {
mutex_lock l(mu_); mutex_lock l(mu_);
--num_active_threads_; --num_active_threads_;
@@ -803,7 +802,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
VLOG(2) << "Starting to read: " << filename; VLOG(2) << "Starting to read: " << filename;
} }
Status s = ReadFile(filename); Status s = ReadFile(env, filename);
// If we get to the end of the file, it's a clean termination and // If we get to the end of the file, it's a clean termination and
// we are at the end of the file. If all files have been processed, // we are at the end of the file. If all files have been processed,
// then we insert an end_of_sequence marker in the buffer and // then we insert an end_of_sequence marker in the buffer and

View File

@@ -209,13 +209,27 @@ Status Writer::WriteRecord(const absl::Cord& data) {
} }
#endif // PLATFORM_GOOGLE #endif // PLATFORM_GOOGLE
Reader::Reader(RandomAccessFile* file, const string& compression_type, Status Reader::Create(Env* env, const std::string& filename,
const string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Reader>* out_reader) {
*out_reader =
absl::WrapUnique(new Reader(filename, compression_type, version, dtypes));
return (*out_reader)->Initialize(env);
}
Reader::Reader(const std::string& filename, const string& compression_type,
int version, const DataTypeVector& dtypes) int version, const DataTypeVector& dtypes)
: file_(file), : filename_(filename),
input_stream_(new io::RandomAccessInputStream(file)),
compression_type_(compression_type), compression_type_(compression_type),
version_(version), version_(version),
dtypes_(dtypes) { dtypes_(dtypes) {}
Status Reader::Initialize(Env* env) {
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
#if defined(IS_SLIM_BUILD) #if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) { if (compression_type_ != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@@ -232,16 +246,16 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
} else if (compression_type_ == io::compression::kSnappy) { } else if (compression_type_ == io::compression::kSnappy) {
if (version_ == 0) { if (version_ == 0) {
input_stream_ = absl::make_unique<io::SnappyInputBuffer>( input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
} else { } else {
input_stream_ = input_stream_ =
absl::make_unique<io::BufferedInputStream>(file_, 64 << 20); absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
} }
} }
#endif // IS_SLIM_BUILD #endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes.size()); simple_tensor_mask_.reserve(dtypes_.size());
for (const auto& dtype : dtypes) { for (const auto& dtype : dtypes_) {
if (DataTypeCanUseMemcpy(dtype)) { if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true); simple_tensor_mask_.push_back(true);
num_simple_++; num_simple_++;
@@ -250,6 +264,8 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
num_complex_++; num_complex_++;
} }
} }
return Status::OK();
} }
Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) { Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {

View File

@@ -114,12 +114,19 @@ class Reader {
static constexpr const char* const kReadCord = "ReadCord"; static constexpr const char* const kReadCord = "ReadCord";
static constexpr const char* const kSeparator = "::"; static constexpr const char* const kSeparator = "::";
explicit Reader(RandomAccessFile* file, const string& compression_type, static Status Create(Env* env, const std::string& filename,
int version, const DataTypeVector& dtypes); const string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Reader>* out_reader);
Status ReadTensors(std::vector<Tensor>* read_tensors); Status ReadTensors(std::vector<Tensor>* read_tensors);
private: private:
explicit Reader(const std::string& filename, const string& compression_type,
int version, const DataTypeVector& dtypes);
Status Initialize(Env* env);
Status ReadTensorsV0(std::vector<Tensor>* read_tensors); Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
Status SnappyUncompress( Status SnappyUncompress(
@@ -134,7 +141,8 @@ class Reader {
Status ReadRecord(absl::Cord* record); Status ReadRecord(absl::Cord* record);
#endif #endif
RandomAccessFile* file_; std::string filename_;
std::unique_ptr<RandomAccessFile> file_;
std::unique_ptr<io::InputStreamInterface> input_stream_; std::unique_ptr<io::InputStreamInterface> input_stream_;
const string compression_type_; const string compression_type_;
const int version_; const int version_;