Changing Snapshot to move on to the next file on the basis of number of bytes written instead of number of IteratorGetNext calls.

PiperOrigin-RevId: 249153172
This commit is contained in:
Rohan Jain 2019-05-20 16:47:20 -07:00 committed by TensorFlower Gardener
parent 0f1841a87a
commit c93ebee01e
3 changed files with 101 additions and 66 deletions

View File

@ -317,6 +317,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/time",
],
)

View File

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/time/clock.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@ -40,13 +41,13 @@ const uint64 kReaderBufferSize = 8 * 1024 * 1024; // 8 MB
const uint64 kOneDayInMicroseconds = 24L * 60L * 60L * 1e6L;
const uint64 kNumElementsPerShard = 10000;
const uint64 kNumMBPerShard = 10 * 1024; // 10 GB per file.
const char kSnapshotFilename[] = "snapshot.metadata";
string GetCurrentSnapshotDataFilename(uint64 next_index,
string GetCurrentSnapshotDataFilename(uint64 bytes_written,
const string& run_dir) {
uint64_t shard_id = next_index / kNumElementsPerShard;
uint64_t shard_id = bytes_written / (1024 * 1024 * kNumMBPerShard);
return absl::StrCat(run_dir, "/", strings::Printf("%08lu", shard_id),
".snapshot");
}
@ -55,7 +56,6 @@ Status WriteMetadataFile(const string& fingerprint_dir,
const experimental::SnapshotMetadataRecord& metadata) {
string metadata_filename =
absl::StrCat(fingerprint_dir, "/", kSnapshotFilename);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(fingerprint_dir));
std::unique_ptr<WritableFile> file;
@ -291,85 +291,112 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
run_id_ = metadata_.run_id();
run_dir_ = absl::StrCat(dataset()->reader_path_prefix_,
fingerprint_dir_, "/", run_id_);
run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_);
// Get all the files in the run_dir.
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
absl::StrCat(run_dir_, "/*"), &filenames_));
if (filenames_.empty()) {
return errors::InvalidArgument("Could not find any files in dir: ",
run_dir_);
}
std::sort(filenames_.begin(), filenames_.end());
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
absl::Time start = absl::Now();
mutex_lock l(mu_);
do {
if (current_reader_) {
string record_bytes;
Status s = current_reader_->ReadRecord(&record_bytes);
if (s.ok()) {
*end_of_sequence = false;
experimental::SnapshotRecord record;
record.ParseFromString(record_bytes);
int64 num_bytes = 0;
for (int i = 0; i < record.tensor_size(); ++i) {
Tensor t;
if (!t.FromProto(record.tensor(i))) {
return errors::DataLoss(
"Unable to parse Tensor from proto.");
}
out_tensors->push_back(t);
num_bytes += t.TotalBytes();
}
absl::Time end = absl::Now();
absl::Duration d = end - start;
time_spent_micros_ += absl::ToInt64Microseconds(d);
kbytes_written_ += static_cast<double>(num_bytes) / 1024.0;
next_index_++;
string snapshot_data_filename =
GetCurrentSnapshotDataFilename(next_index_, run_dir_);
if (next_index_ % 10000 == 0) {
LOG(INFO) << "Current read throughput (MBPS): "
<< (kbytes_written_ * 1000000.0) /
(time_spent_micros_ * 1024.0);
}
return Status::OK();
} else if (!errors::IsOutOfRange(s)) {
// Report non-EOF errors to the caller.
return s;
}
// Now that we're reached the end of the current file, lets move
// on to the next file.
ResetReaderLocked();
++current_file_index_;
}
if (current_read_filename_ != snapshot_data_filename) {
current_reader_.reset();
current_read_file_.reset();
// The current implementation here assumes that tensors are stored
// in files which are named sequentially. If a file doesn't exist
// when we try reading that item, we assume that we have reached the
// end of the snapshot.
Status s = Env::Default()->FileExists(snapshot_data_filename);
if (!s.ok()) {
if (current_file_index_ == filenames_.size()) {
*end_of_sequence = true;
return Status::OK();
}
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(
snapshot_data_filename, &current_read_file_));
auto reader_options =
io::RecordReaderOptions::CreateRecordReaderOptions(
dataset()->compression_);
reader_options.buffer_size = kReaderBufferSize;
current_reader_ = absl::make_unique<io::SequentialRecordReader>(
current_read_file_.get(), reader_options);
current_read_filename_ = snapshot_data_filename;
}
string record_bytes;
Status s = current_reader_->ReadRecord(&record_bytes);
if (errors::IsOutOfRange(s)) {
*end_of_sequence = true;
return Status::OK();
} else if (!s.ok()) {
return s;
}
*end_of_sequence = false;
experimental::SnapshotRecord record;
record.ParseFromString(record_bytes);
for (int i = 0; i < record.tensor_size(); ++i) {
Tensor t;
if (!t.FromProto(record.tensor(i))) {
return errors::DataLoss("Unable to parse Tensor from proto.");
}
out_tensors->push_back(t);
}
next_index_++;
return Status::OK();
TF_RETURN_IF_ERROR(SetupReaderLocked(ctx->env()));
} while (true);
}
private:
Status SetupReaderLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (current_file_index_ >= filenames_.size()) {
return errors::InvalidArgument("current_files_index_...");
}
const string filename = absl::StrCat(dataset()->reader_path_prefix_,
filenames_[current_file_index_]);
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename,
&current_read_file_));
auto reader_options =
io::RecordReaderOptions::CreateRecordReaderOptions(
dataset()->compression_);
reader_options.buffer_size = kReaderBufferSize;
current_reader_ = absl::make_unique<io::SequentialRecordReader>(
current_read_file_.get(), reader_options);
return Status::OK();
}
void ResetReaderLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
current_reader_.reset();
current_read_file_.reset();
}
const string fingerprint_dir_;
const experimental::SnapshotMetadataRecord metadata_;
string run_id_ GUARDED_BY(mu_);
string run_dir_ GUARDED_BY(mu_);
std::vector<string> filenames_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
string current_read_filename_ GUARDED_BY(mu_);
std::unique_ptr<RandomAccessFile> current_read_file_ GUARDED_BY(mu_);
std::unique_ptr<io::SequentialRecordReader> current_reader_
GUARDED_BY(mu_);
int64 next_index_ GUARDED_BY(mu_) = 0;
uint64 next_index_ GUARDED_BY(mu_) = 0;
int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
double kbytes_written_ GUARDED_BY(mu_) = 0;
size_t current_file_index_ GUARDED_BY(mu_) = 0;
mutex mu_;
};
@ -405,6 +432,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
absl::Time start = absl::Now();
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
@ -433,7 +461,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
string snapshot_data_filename =
GetCurrentSnapshotDataFilename(next_index_, run_dir_);
GetCurrentSnapshotDataFilename(bytes_written_, run_dir_);
if (current_write_filename_ != snapshot_data_filename) {
if (current_writer_) TF_RETURN_IF_ERROR(current_writer_->Close());
@ -456,7 +484,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
experimental::SnapshotRecord record;
int64 num_bytes = 0;
for (auto out_tensor : *out_tensors) {
num_bytes += out_tensor.TotalBytes();
TensorProto* t = record.add_tensor();
out_tensor.AsProtoTensorContent(t);
}
@ -464,7 +494,18 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(
current_writer_->WriteRecord(record.SerializeAsString()));
absl::Time end = absl::Now();
absl::Duration d = end - start;
time_spent_micros_ += absl::ToInt64Microseconds(d);
bytes_written_ += num_bytes;
next_index_++;
if (next_index_ % 10000 == 0) {
LOG(INFO) << "Current write throughput (MBPS): "
<< (bytes_written_ * 1000000.0) /
(time_spent_micros_ * 1024.0 * 1024.0);
}
return Status::OK();
}
@ -480,6 +521,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<io::RecordWriter> current_writer_ GUARDED_BY(mu_);
uint64 next_index_ GUARDED_BY(mu_) = 0;
int64 time_spent_micros_ GUARDED_BY(mu_) = 0;
int64 bytes_written_ GUARDED_BY(mu_) = 0;
mutex mu_;
};

View File

@ -128,15 +128,6 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
def testWriteSnapshotMultiFileSuccessful(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(20000)
dataset = dataset.apply(snapshot.snapshot(tmpdir))
self.assertDatasetProduces(dataset, list(range(20000)))
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 2)
@parameterized.parameters(snapshot.COMPRESSION_NONE,
snapshot.COMPRESSION_GZIP)
def testReadSnapshotBackAfterWrite(self, compression):