Changing Snapshot to move on to the next file on the basis of number of bytes written instead of number of IteratorGetNext calls.
PiperOrigin-RevId: 249153172
This commit is contained in:
parent
0f1841a87a
commit
c93ebee01e
@ -317,6 +317,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "absl/time/clock.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -40,13 +41,13 @@ const uint64 kReaderBufferSize = 8 * 1024 * 1024; // 8 MB
|
||||
|
||||
const uint64 kOneDayInMicroseconds = 24L * 60L * 60L * 1e6L;
|
||||
|
||||
const uint64 kNumElementsPerShard = 10000;
|
||||
const uint64 kNumMBPerShard = 10 * 1024; // 10 GB per file.
|
||||
|
||||
const char kSnapshotFilename[] = "snapshot.metadata";
|
||||
|
||||
string GetCurrentSnapshotDataFilename(uint64 next_index,
|
||||
string GetCurrentSnapshotDataFilename(uint64 bytes_written,
|
||||
const string& run_dir) {
|
||||
uint64_t shard_id = next_index / kNumElementsPerShard;
|
||||
uint64_t shard_id = bytes_written / (1024 * 1024 * kNumMBPerShard);
|
||||
return absl::StrCat(run_dir, "/", strings::Printf("%08lu", shard_id),
|
||||
".snapshot");
|
||||
}
|
||||
@ -55,7 +56,6 @@ Status WriteMetadataFile(const string& fingerprint_dir,
|
||||
const experimental::SnapshotMetadataRecord& metadata) {
|
||||
string metadata_filename =
|
||||
absl::StrCat(fingerprint_dir, "/", kSnapshotFilename);
|
||||
|
||||
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(fingerprint_dir));
|
||||
|
||||
std::unique_ptr<WritableFile> file;
|
||||
@ -291,85 +291,112 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
mutex_lock l(mu_);
|
||||
|
||||
run_id_ = metadata_.run_id();
|
||||
run_dir_ = absl::StrCat(dataset()->reader_path_prefix_,
|
||||
fingerprint_dir_, "/", run_id_);
|
||||
run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_);
|
||||
// Get all the files in the run_dir.
|
||||
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
|
||||
absl::StrCat(run_dir_, "/*"), &filenames_));
|
||||
if (filenames_.empty()) {
|
||||
return errors::InvalidArgument("Could not find any files in dir: ",
|
||||
run_dir_);
|
||||
}
|
||||
std::sort(filenames_.begin(), filenames_.end());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
absl::Time start = absl::Now();
|
||||
mutex_lock l(mu_);
|
||||
do {
|
||||
if (current_reader_) {
|
||||
string record_bytes;
|
||||
Status s = current_reader_->ReadRecord(&record_bytes);
|
||||
if (s.ok()) {
|
||||
*end_of_sequence = false;
|
||||
experimental::SnapshotRecord record;
|
||||
record.ParseFromString(record_bytes);
|
||||
int64 num_bytes = 0;
|
||||
for (int i = 0; i < record.tensor_size(); ++i) {
|
||||
Tensor t;
|
||||
if (!t.FromProto(record.tensor(i))) {
|
||||
return errors::DataLoss(
|
||||
"Unable to parse Tensor from proto.");
|
||||
}
|
||||
out_tensors->push_back(t);
|
||||
num_bytes += t.TotalBytes();
|
||||
}
|
||||
absl::Time end = absl::Now();
|
||||
absl::Duration d = end - start;
|
||||
time_spent_micros_ += absl::ToInt64Microseconds(d);
|
||||
kbytes_written_ += static_cast<double>(num_bytes) / 1024.0;
|
||||
next_index_++;
|
||||
|
||||
string snapshot_data_filename =
|
||||
GetCurrentSnapshotDataFilename(next_index_, run_dir_);
|
||||
if (next_index_ % 10000 == 0) {
|
||||
LOG(INFO) << "Current read throughput (MBPS): "
|
||||
<< (kbytes_written_ * 1000000.0) /
|
||||
(time_spent_micros_ * 1024.0);
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (!errors::IsOutOfRange(s)) {
|
||||
// Report non-EOF errors to the caller.
|
||||
return s;
|
||||
}
|
||||
// Now that we're reached the end of the current file, lets move
|
||||
// on to the next file.
|
||||
ResetReaderLocked();
|
||||
++current_file_index_;
|
||||
}
|
||||
|
||||
if (current_read_filename_ != snapshot_data_filename) {
|
||||
current_reader_.reset();
|
||||
current_read_file_.reset();
|
||||
|
||||
// The current implementation here assumes that tensors are stored
|
||||
// in files which are named sequentially. If a file doesn't exist
|
||||
// when we try reading that item, we assume that we have reached the
|
||||
// end of the snapshot.
|
||||
Status s = Env::Default()->FileExists(snapshot_data_filename);
|
||||
if (!s.ok()) {
|
||||
if (current_file_index_ == filenames_.size()) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(
|
||||
snapshot_data_filename, ¤t_read_file_));
|
||||
auto reader_options =
|
||||
io::RecordReaderOptions::CreateRecordReaderOptions(
|
||||
dataset()->compression_);
|
||||
reader_options.buffer_size = kReaderBufferSize;
|
||||
|
||||
current_reader_ = absl::make_unique<io::SequentialRecordReader>(
|
||||
current_read_file_.get(), reader_options);
|
||||
current_read_filename_ = snapshot_data_filename;
|
||||
}
|
||||
|
||||
string record_bytes;
|
||||
Status s = current_reader_->ReadRecord(&record_bytes);
|
||||
|
||||
if (errors::IsOutOfRange(s)) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
} else if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
|
||||
*end_of_sequence = false;
|
||||
experimental::SnapshotRecord record;
|
||||
record.ParseFromString(record_bytes);
|
||||
|
||||
for (int i = 0; i < record.tensor_size(); ++i) {
|
||||
Tensor t;
|
||||
if (!t.FromProto(record.tensor(i))) {
|
||||
return errors::DataLoss("Unable to parse Tensor from proto.");
|
||||
}
|
||||
out_tensors->push_back(t);
|
||||
}
|
||||
|
||||
next_index_++;
|
||||
return Status::OK();
|
||||
TF_RETURN_IF_ERROR(SetupReaderLocked(ctx->env()));
|
||||
} while (true);
|
||||
}
|
||||
|
||||
private:
|
||||
Status SetupReaderLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (current_file_index_ >= filenames_.size()) {
|
||||
return errors::InvalidArgument("current_files_index_...");
|
||||
}
|
||||
const string filename = absl::StrCat(dataset()->reader_path_prefix_,
|
||||
filenames_[current_file_index_]);
|
||||
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename,
|
||||
¤t_read_file_));
|
||||
auto reader_options =
|
||||
io::RecordReaderOptions::CreateRecordReaderOptions(
|
||||
dataset()->compression_);
|
||||
reader_options.buffer_size = kReaderBufferSize;
|
||||
|
||||
current_reader_ = absl::make_unique<io::SequentialRecordReader>(
|
||||
current_read_file_.get(), reader_options);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void ResetReaderLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
current_reader_.reset();
|
||||
current_read_file_.reset();
|
||||
}
|
||||
|
||||
const string fingerprint_dir_;
|
||||
const experimental::SnapshotMetadataRecord metadata_;
|
||||
string run_id_ GUARDED_BY(mu_);
|
||||
string run_dir_ GUARDED_BY(mu_);
|
||||
std::vector<string> filenames_;
|
||||
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
|
||||
string current_read_filename_ GUARDED_BY(mu_);
|
||||
std::unique_ptr<RandomAccessFile> current_read_file_ GUARDED_BY(mu_);
|
||||
std::unique_ptr<io::SequentialRecordReader> current_reader_
|
||||
GUARDED_BY(mu_);
|
||||
|
||||
int64 next_index_ GUARDED_BY(mu_) = 0;
|
||||
uint64 next_index_ GUARDED_BY(mu_) = 0;
|
||||
int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
|
||||
double kbytes_written_ GUARDED_BY(mu_) = 0;
|
||||
size_t current_file_index_ GUARDED_BY(mu_) = 0;
|
||||
|
||||
mutex mu_;
|
||||
};
|
||||
@ -405,6 +432,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
absl::Time start = absl::Now();
|
||||
mutex_lock l(mu_);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -433,7 +461,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
string snapshot_data_filename =
|
||||
GetCurrentSnapshotDataFilename(next_index_, run_dir_);
|
||||
GetCurrentSnapshotDataFilename(bytes_written_, run_dir_);
|
||||
|
||||
if (current_write_filename_ != snapshot_data_filename) {
|
||||
if (current_writer_) TF_RETURN_IF_ERROR(current_writer_->Close());
|
||||
@ -456,7 +484,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
experimental::SnapshotRecord record;
|
||||
|
||||
int64 num_bytes = 0;
|
||||
for (auto out_tensor : *out_tensors) {
|
||||
num_bytes += out_tensor.TotalBytes();
|
||||
TensorProto* t = record.add_tensor();
|
||||
out_tensor.AsProtoTensorContent(t);
|
||||
}
|
||||
@ -464,7 +494,18 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
TF_RETURN_IF_ERROR(
|
||||
current_writer_->WriteRecord(record.SerializeAsString()));
|
||||
|
||||
absl::Time end = absl::Now();
|
||||
absl::Duration d = end - start;
|
||||
time_spent_micros_ += absl::ToInt64Microseconds(d);
|
||||
bytes_written_ += num_bytes;
|
||||
|
||||
next_index_++;
|
||||
|
||||
if (next_index_ % 10000 == 0) {
|
||||
LOG(INFO) << "Current write throughput (MBPS): "
|
||||
<< (bytes_written_ * 1000000.0) /
|
||||
(time_spent_micros_ * 1024.0 * 1024.0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -480,6 +521,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::unique_ptr<io::RecordWriter> current_writer_ GUARDED_BY(mu_);
|
||||
|
||||
uint64 next_index_ GUARDED_BY(mu_) = 0;
|
||||
int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
|
||||
int64 bytes_written_ GUARDED_BY(mu_) = 0;
|
||||
|
||||
mutex mu_;
|
||||
};
|
||||
|
@ -128,15 +128,6 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
|
||||
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
||||
|
||||
def testWriteSnapshotMultiFileSuccessful(self):
|
||||
tmpdir = self.makeSnapshotDirectory()
|
||||
|
||||
dataset = dataset_ops.Dataset.range(20000)
|
||||
dataset = dataset.apply(snapshot.snapshot(tmpdir))
|
||||
self.assertDatasetProduces(dataset, list(range(20000)))
|
||||
|
||||
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 2)
|
||||
|
||||
@parameterized.parameters(snapshot.COMPRESSION_NONE,
|
||||
snapshot.COMPRESSION_GZIP)
|
||||
def testReadSnapshotBackAfterWrite(self, compression):
|
||||
|
Loading…
x
Reference in New Issue
Block a user