[tf.data] Refactor snapshot_util::Reader to own a RandomAccessFile
PiperOrigin-RevId: 306977814 Change-Id: I78140f6e01ac457ffd45c2c083e6d87ed60cccec
This commit is contained in:
parent
b5f7f775aa
commit
9b24e4fa8b
@ -549,7 +549,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (!background_threads_started_) {
|
||||
for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
|
||||
++num_active_threads_;
|
||||
thread_pool_->Schedule([this, i]() { ReadingFilesLoop(i); });
|
||||
thread_pool_->Schedule(
|
||||
[this, i, env = ctx->env()]() { ReadingFilesLoop(env, i); });
|
||||
}
|
||||
background_threads_started_ = true;
|
||||
}
|
||||
@ -731,13 +732,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
private:
|
||||
// Reads one file end to end.
|
||||
Status ReadFile(const string& filename) {
|
||||
std::unique_ptr<RandomAccessFile> file;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Env::Default()->NewRandomAccessFile(filename, &file));
|
||||
snapshot_util::Reader reader(file.get(), dataset()->compression_,
|
||||
version_, dataset()->output_dtypes());
|
||||
|
||||
Status ReadFile(Env* env, const string& filename) {
|
||||
std::unique_ptr<snapshot_util::Reader> reader;
|
||||
TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
|
||||
Env::Default(), filename, dataset()->compression_, version_,
|
||||
dataset()->output_dtypes(), &reader));
|
||||
while (true) {
|
||||
// Wait for a slot in the buffer.
|
||||
{
|
||||
@ -754,7 +753,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
}
|
||||
std::vector<Tensor> read_tensors;
|
||||
Status s = reader.ReadTensors(&read_tensors);
|
||||
Status s = reader->ReadTensors(&read_tensors);
|
||||
if (s.ok()) {
|
||||
profiler::TraceMe activity(
|
||||
[&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
|
||||
@ -787,7 +786,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
// Pulls one file off the filenames_ list and reads it through. When
|
||||
// all files are read, terminates.
|
||||
void ReadingFilesLoop(int i) {
|
||||
void ReadingFilesLoop(Env* env, int i) {
|
||||
auto cleanup = gtl::MakeCleanup([this]() {
|
||||
mutex_lock l(mu_);
|
||||
--num_active_threads_;
|
||||
@ -803,7 +802,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
VLOG(2) << "Starting to read: " << filename;
|
||||
}
|
||||
Status s = ReadFile(filename);
|
||||
Status s = ReadFile(env, filename);
|
||||
// If we get to the end of the file, it's a clean termination and
|
||||
// we are at the end of the file. If all files have been processed,
|
||||
// then we insert an end_of_sequence marker in the buffer and
|
||||
|
@ -209,13 +209,27 @@ Status Writer::WriteRecord(const absl::Cord& data) {
|
||||
}
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
Reader::Reader(RandomAccessFile* file, const string& compression_type,
|
||||
Status Reader::Create(Env* env, const std::string& filename,
|
||||
const string& compression_type, int version,
|
||||
const DataTypeVector& dtypes,
|
||||
std::unique_ptr<Reader>* out_reader) {
|
||||
*out_reader =
|
||||
absl::WrapUnique(new Reader(filename, compression_type, version, dtypes));
|
||||
|
||||
return (*out_reader)->Initialize(env);
|
||||
}
|
||||
|
||||
Reader::Reader(const std::string& filename, const string& compression_type,
|
||||
int version, const DataTypeVector& dtypes)
|
||||
: file_(file),
|
||||
input_stream_(new io::RandomAccessInputStream(file)),
|
||||
: filename_(filename),
|
||||
compression_type_(compression_type),
|
||||
version_(version),
|
||||
dtypes_(dtypes) {
|
||||
dtypes_(dtypes) {}
|
||||
|
||||
Status Reader::Initialize(Env* env) {
|
||||
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
|
||||
input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
|
||||
|
||||
#if defined(IS_SLIM_BUILD)
|
||||
if (compression_type_ != io::compression::kNone) {
|
||||
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
|
||||
@ -232,16 +246,16 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
|
||||
} else if (compression_type_ == io::compression::kSnappy) {
|
||||
if (version_ == 0) {
|
||||
input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
|
||||
file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
|
||||
file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
|
||||
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
|
||||
} else {
|
||||
input_stream_ =
|
||||
absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
|
||||
absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
|
||||
}
|
||||
}
|
||||
#endif // IS_SLIM_BUILD
|
||||
simple_tensor_mask_.reserve(dtypes.size());
|
||||
for (const auto& dtype : dtypes) {
|
||||
simple_tensor_mask_.reserve(dtypes_.size());
|
||||
for (const auto& dtype : dtypes_) {
|
||||
if (DataTypeCanUseMemcpy(dtype)) {
|
||||
simple_tensor_mask_.push_back(true);
|
||||
num_simple_++;
|
||||
@ -250,6 +264,8 @@ Reader::Reader(RandomAccessFile* file, const string& compression_type,
|
||||
num_complex_++;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
|
||||
|
@ -114,12 +114,19 @@ class Reader {
|
||||
static constexpr const char* const kReadCord = "ReadCord";
|
||||
static constexpr const char* const kSeparator = "::";
|
||||
|
||||
explicit Reader(RandomAccessFile* file, const string& compression_type,
|
||||
int version, const DataTypeVector& dtypes);
|
||||
static Status Create(Env* env, const std::string& filename,
|
||||
const string& compression_type, int version,
|
||||
const DataTypeVector& dtypes,
|
||||
std::unique_ptr<Reader>* out_reader);
|
||||
|
||||
Status ReadTensors(std::vector<Tensor>* read_tensors);
|
||||
|
||||
private:
|
||||
explicit Reader(const std::string& filename, const string& compression_type,
|
||||
int version, const DataTypeVector& dtypes);
|
||||
|
||||
Status Initialize(Env* env);
|
||||
|
||||
Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
|
||||
|
||||
Status SnappyUncompress(
|
||||
@ -134,7 +141,8 @@ class Reader {
|
||||
Status ReadRecord(absl::Cord* record);
|
||||
#endif
|
||||
|
||||
RandomAccessFile* file_;
|
||||
std::string filename_;
|
||||
std::unique_ptr<RandomAccessFile> file_;
|
||||
std::unique_ptr<io::InputStreamInterface> input_stream_;
|
||||
const string compression_type_;
|
||||
const int version_;
|
||||
|
Loading…
Reference in New Issue
Block a user