tensor_bundle: fix that the read path forgets to cache file handles.

In a case where a reader is geographically far from the file, this change
achieves a speedup of end-to-end checkpoint restore by 5.8x.

PiperOrigin-RevId: 157889659
This commit is contained in:
Zongheng Yang 2017-06-02 16:28:10 -07:00 committed by TensorFlower Gardener
parent 0c92dada6a
commit 9b8f6113b7
2 changed files with 15 additions and 6 deletions

View File

@ -640,6 +640,12 @@ BundleReader::~BundleReader() {
delete metadata_;
delete iter_;
delete table_;
// InputBuffer does not own the underlying RandomAccessFile.
for (auto pair : data_) {
if (pair.second->file() != nullptr) {
delete pair.second->file();
}
}
gtl::STLDeleteValues(&data_);
gtl::STLDeleteValues(&tensor_slices_);
}
@ -694,14 +700,16 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
}
}
// Open the data file if not opened it.
std::unique_ptr<RandomAccessFile> file = nullptr;
std::unique_ptr<io::InputBuffer> buffered_file(data_[entry.shard_id()]);
// Open the data file if it has not been opened.
io::InputBuffer* buffered_file = data_[entry.shard_id()];
if (buffered_file == nullptr) {
std::unique_ptr<RandomAccessFile> file = nullptr;
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
buffered_file.reset(
new io::InputBuffer(file.get(), 256 << 10 /* 256KB buffer */));
buffered_file =
new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */);
// The InputBuffer and RandomAccessFile objects are both released in dtor.
data_[entry.shard_id()] = buffered_file;
}
CHECK(buffered_file != nullptr);
@ -720,7 +728,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
// Relies on io::InputBuffer's buffering, because we issue many neighboring
// reads for a single string tensor.
TF_RETURN_IF_ERROR(ReadStringTensor(
buffered_file.get(), ret->NumElements(), entry.offset(), entry.size(),
buffered_file, ret->NumElements(), entry.offset(), entry.size(),
GetStringBackingBuffer(*ret), &actual_crc32c));
}
if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {

View File

@ -273,6 +273,7 @@ class BundleReader {
RandomAccessFile* metadata_; // Owned.
table::Table* table_;
table::Iterator* iter_;
// Owned the InputBuffer objects and their underlying RandomAccessFile's.
std::unordered_map<int32, io::InputBuffer*> data_;
// Maps each partitioned tensor's key to its stored slices (represented in a