tensor_bundle: fix that the read path forgets to cache file handles.
In a case where a reader is geographically far from the file, this change achieves a speedup of end-to-end checkpoint restore by 5.8x. PiperOrigin-RevId: 157889659
This commit is contained in:
parent
0c92dada6a
commit
9b8f6113b7
@ -640,6 +640,12 @@ BundleReader::~BundleReader() {
|
||||
delete metadata_;
|
||||
delete iter_;
|
||||
delete table_;
|
||||
// InputBuffer does not own the underlying RandomAccessFile.
|
||||
for (auto pair : data_) {
|
||||
if (pair.second->file() != nullptr) {
|
||||
delete pair.second->file();
|
||||
}
|
||||
}
|
||||
gtl::STLDeleteValues(&data_);
|
||||
gtl::STLDeleteValues(&tensor_slices_);
|
||||
}
|
||||
@ -694,14 +700,16 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
|
||||
}
|
||||
}
|
||||
|
||||
// Open the data file if not opened it.
|
||||
std::unique_ptr<RandomAccessFile> file = nullptr;
|
||||
std::unique_ptr<io::InputBuffer> buffered_file(data_[entry.shard_id()]);
|
||||
// Open the data file if it has not been opened.
|
||||
io::InputBuffer* buffered_file = data_[entry.shard_id()];
|
||||
if (buffered_file == nullptr) {
|
||||
std::unique_ptr<RandomAccessFile> file = nullptr;
|
||||
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
|
||||
DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
|
||||
buffered_file.reset(
|
||||
new io::InputBuffer(file.get(), 256 << 10 /* 256KB buffer */));
|
||||
buffered_file =
|
||||
new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */);
|
||||
// The InputBuffer and RandomAccessFile objects are both released in dtor.
|
||||
data_[entry.shard_id()] = buffered_file;
|
||||
}
|
||||
CHECK(buffered_file != nullptr);
|
||||
|
||||
@ -720,7 +728,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
|
||||
// Relies on io::InputBuffer's buffering, because we issue many neighboring
|
||||
// reads for a single string tensor.
|
||||
TF_RETURN_IF_ERROR(ReadStringTensor(
|
||||
buffered_file.get(), ret->NumElements(), entry.offset(), entry.size(),
|
||||
buffered_file, ret->NumElements(), entry.offset(), entry.size(),
|
||||
GetStringBackingBuffer(*ret), &actual_crc32c));
|
||||
}
|
||||
if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
|
||||
|
@ -273,6 +273,7 @@ class BundleReader {
|
||||
RandomAccessFile* metadata_; // Owned.
|
||||
table::Table* table_;
|
||||
table::Iterator* iter_;
|
||||
// Owned the InputBuffer objects and their underlying RandomAccessFile's.
|
||||
std::unordered_map<int32, io::InputBuffer*> data_;
|
||||
|
||||
// Maps each partitioned tensor's key to its stored slices (represented in a
|
||||
|
Loading…
Reference in New Issue
Block a user