[tf.data] Refactor snapshot_util::Reader to own a RandomAccessFile

PiperOrigin-RevId: 306977814
Change-Id: I78140f6e01ac457ffd45c2c083e6d87ed60cccec
Frank Chen, 2020-04-16 21:03:53 -07:00 (committed by TensorFlower Gardener)
parent b5f7f775aa
commit 9b24e4fa8b
3 changed files with 45 additions and 22 deletions
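
The caller-side effect of this refactor is summarized by the sketch below. It is a rough before/after comparison based on the hunks that follow, not code from this commit; `filename`, `compression`, `version`, `dtypes`, and `tensors` stand in for whatever the caller already has, and the relevant TensorFlow headers and namespaces are assumed to be in scope (as they are at the call sites below).

// Before: the caller opens the RandomAccessFile and must keep it alive for as
// long as the Reader (and the input stream built on top of it) is in use.
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename, &file));
snapshot_util::Reader reader(file.get(), compression, version, dtypes);
TF_RETURN_IF_ERROR(reader.ReadTensors(&tensors));

// After: Reader::Create opens the file and the Reader owns it, so the file's
// lifetime is tied to the Reader's.
std::unique_ptr<snapshot_util::Reader> owned_reader;
TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
    Env::Default(), filename, compression, version, dtypes, &owned_reader));
TF_RETURN_IF_ERROR(owned_reader->ReadTensors(&tensors));

Internally, Create wraps a private constructor plus an Initialize(Env*) step that opens the file and builds the input stream, which is why ReadFile in the dataset op below reduces to a single Create call.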

View File

@@ -549,7 +549,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           if (!background_threads_started_) {
             for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
               ++num_active_threads_;
-              thread_pool_->Schedule([this, i]() { ReadingFilesLoop(i); });
+              thread_pool_->Schedule(
+                  [this, i, env = ctx->env()]() { ReadingFilesLoop(env, i); });
             }
             background_threads_started_ = true;
           }
@@ -731,13 +732,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        private:
         // Reads one file end to end.
-        Status ReadFile(const string& filename) {
-          std::unique_ptr<RandomAccessFile> file;
-          TF_RETURN_IF_ERROR(
-              Env::Default()->NewRandomAccessFile(filename, &file));
-          snapshot_util::Reader reader(file.get(), dataset()->compression_,
-                                       version_, dataset()->output_dtypes());
+        Status ReadFile(Env* env, const string& filename) {
+          std::unique_ptr<snapshot_util::Reader> reader;
+          TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
+              Env::Default(), filename, dataset()->compression_, version_,
+              dataset()->output_dtypes(), &reader));
           while (true) {
             // Wait for a slot in the buffer.
             {
@@ -754,7 +753,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               }
             }
             std::vector<Tensor> read_tensors;
-            Status s = reader.ReadTensors(&read_tensors);
+            Status s = reader->ReadTensors(&read_tensors);
             if (s.ok()) {
               profiler::TraceMe activity(
                   [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
@@ -787,7 +786,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         // Pulls one file off the filenames_ list and reads it through. When
         // all files are read, terminates.
-        void ReadingFilesLoop(int i) {
+        void ReadingFilesLoop(Env* env, int i) {
           auto cleanup = gtl::MakeCleanup([this]() {
             mutex_lock l(mu_);
             --num_active_threads_;
@@ -803,7 +802,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
              }
              VLOG(2) << "Starting to read: " << filename;
            }
-           Status s = ReadFile(filename);
+           Status s = ReadFile(env, filename);
            // If we get to the end of the file, it's a clean termination and
            // we are at the end of the file. If all files have been processed,
            // then we insert an end_of_sequence marker in the buffer and

View File

@@ -209,13 +209,27 @@ Status Writer::WriteRecord(const absl::Cord& data) {
 }
 #endif  // PLATFORM_GOOGLE
 
-Reader::Reader(RandomAccessFile* file, const string& compression_type,
+Status Reader::Create(Env* env, const std::string& filename,
+                      const string& compression_type, int version,
+                      const DataTypeVector& dtypes,
+                      std::unique_ptr<Reader>* out_reader) {
+  *out_reader =
+      absl::WrapUnique(new Reader(filename, compression_type, version, dtypes));
+  return (*out_reader)->Initialize(env);
+}
+
+Reader::Reader(const std::string& filename, const string& compression_type,
                int version, const DataTypeVector& dtypes)
-    : file_(file),
-      input_stream_(new io::RandomAccessInputStream(file)),
+    : filename_(filename),
       compression_type_(compression_type),
       version_(version),
-      dtypes_(dtypes) {
+      dtypes_(dtypes) {}
+
+Status Reader::Initialize(Env* env) {
+  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
+  input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
+
 #if defined(IS_SLIM_BUILD)
   if (compression_type_ != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@@ -232,16 +246,16 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
   } else if (compression_type_ == io::compression::kSnappy) {
     if (version_ == 0) {
       input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
-          file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
+          file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
           /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
     } else {
       input_stream_ =
-          absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
+          absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
     }
   }
 #endif  // IS_SLIM_BUILD
-  simple_tensor_mask_.reserve(dtypes.size());
-  for (const auto& dtype : dtypes) {
+  simple_tensor_mask_.reserve(dtypes_.size());
+  for (const auto& dtype : dtypes_) {
     if (DataTypeCanUseMemcpy(dtype)) {
       simple_tensor_mask_.push_back(true);
       num_simple_++;
@@ -250,6 +264,8 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
       num_complex_++;
     }
   }
+
+  return Status::OK();
 }
 
 Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {

View File

@@ -114,12 +114,19 @@ class Reader {
   static constexpr const char* const kReadCord = "ReadCord";
   static constexpr const char* const kSeparator = "::";
 
-  explicit Reader(RandomAccessFile* file, const string& compression_type,
-                  int version, const DataTypeVector& dtypes);
+  static Status Create(Env* env, const std::string& filename,
+                       const string& compression_type, int version,
+                       const DataTypeVector& dtypes,
+                       std::unique_ptr<Reader>* out_reader);
 
   Status ReadTensors(std::vector<Tensor>* read_tensors);
 
  private:
+  explicit Reader(const std::string& filename, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);
+
+  Status Initialize(Env* env);
+
   Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
 
   Status SnappyUncompress(
@@ -134,7 +141,8 @@ class Reader {
   Status ReadRecord(absl::Cord* record);
 #endif
 
-  RandomAccessFile* file_;
+  std::string filename_;
+  std::unique_ptr<RandomAccessFile> file_;
   std::unique_ptr<io::InputStreamInterface> input_stream_;
   const string compression_type_;
   const int version_;