[tf.data] Cleanup and refactoring of tf.data snapshot implementation.

PiperOrigin-RevId: 314866920
Change-Id: If68df309d14eee4b697e74242705da93c9b03a99
Jiri Simsa 2020-06-04 22:22:41 -07:00 committed by TensorFlower Gardener
parent da59266da2
commit 9221044560
3 changed files with 270 additions and 235 deletions


@@ -65,8 +65,6 @@ namespace experimental {
// ==== Snapshot Implementation ====
namespace {
/* The current snapshot on-disk layout is as follows:
* /user/specified/path/
* - graphhash1/
@@ -95,25 +93,6 @@ namespace {
* ...
*/
constexpr const char* const kShardDirectorySuffix = ".shard";
inline tstring HashDirectory(const tstring& path, const uint64 hash) {
return io::JoinPath(path, absl::StrFormat("%d", hash));
}
inline tstring RunDirectory(const tstring& hash_directory,
const uint64 run_id) {
return io::JoinPath(hash_directory, absl::StrFormat("%d", run_id));
}
inline tstring SnapshotShardDirectory(const tstring& run_directory,
const int64 snapshot_index) {
return io::JoinPath(run_directory, absl::StrFormat("%08d%s", snapshot_index,
kShardDirectorySuffix));
}
} // namespace
class SnapshotDatasetV2Op::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
@@ -248,118 +227,19 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Writer
IteratorStateReader* reader) override;
private:
struct BufferElement {
std::vector<Tensor> value;
bool end_of_sequence = false;
};
class WriterThread {
public:
explicit WriterThread(Writer* writer_iterator, Env* env, int64 file_index,
const tstring& shard_directory,
uint64 current_checkpoint_id) {
thread_ = absl::WrapUnique(env->StartThread(
ThreadOptions(), absl::StrCat("snapshot_writer_thread_", file_index),
[this, writer_iterator, env, shard_directory, current_checkpoint_id] {
RunWriterThread(writer_iterator, env, shard_directory,
current_checkpoint_id);
}));
}
void EnqueueTensors(const std::vector<Tensor>& tensors)
TF_LOCKS_EXCLUDED(deque_mu_) {
// Copy the tensors to the deque for writing.
mutex_lock l(deque_mu_);
BufferElement be;
be.value = tensors;
deque_.push_back(std::move(be));
}
void DequeueTensors(BufferElement* be) TF_LOCKS_EXCLUDED(deque_mu_) {
mutex_lock l(deque_mu_);
deque_mu_.Await(
tensorflow::Condition(this, &WriterThread::DequeIsNotEmpty));
*be = deque_.front();
deque_.pop_front();
}
void StopThread() TF_LOCKS_EXCLUDED(deque_mu_) {
mutex_lock l(deque_mu_);
BufferElement be;
be.end_of_sequence = true;
deque_.push_back(std::move(be));
}
void RunWriterThread(Writer* writer_iterator, Env* env,
const tstring& shard_directory,
uint64 current_checkpoint_id) {
Status s = WriterThreadFn(writer_iterator, env, shard_directory,
current_checkpoint_id);
if (!s.ok()) {
mutex_lock l(writer_iterator->mu_);
writer_iterator->writer_status_ = s;
}
}
private:
bool DequeIsNotEmpty() TF_EXCLUSIVE_LOCKS_REQUIRED(deque_mu_) {
return !deque_.empty();
}
Status WriterThreadFn(Writer* writer_iterator, Env* env,
const tstring& shard_directory,
uint64 current_checkpoint_id) {
std::unique_ptr<snapshot_util::Writer> writer;
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));
TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
env,
snapshot_util::GetCurrentCheckpointFile(shard_directory,
current_checkpoint_id),
writer_iterator->dataset()->compression_, kFileFormatVersion,
writer_iterator->dataset()->output_dtypes(), &writer));
while (true) {
BufferElement be;
DequeueTensors(&be);
if (be.end_of_sequence) {
TF_RETURN_IF_ERROR(writer->Close());
break;
}
TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
}
return Status::OK();
}
// If both the writer `mu_` and this `deque_mu_` need to be acquired, the
// writer `mu_` must be acquired first.
mutex deque_mu_;
std::deque<BufferElement> deque_ TF_GUARDED_BY(deque_mu_);
// This has to be last. During destruction, we need to make sure that
// thread_ is destroyed first as the thread destructor blocks on thread
// completion. If there are other member variables after this, they may get
// destroyed first before the thread finishes, potentially causing the
// thread to access invalid memory.
std::unique_ptr<Thread> thread_;
};
Status GetShardIndex(std::vector<Tensor>* tensors, int64* shard_index)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status GetShardIndex(IteratorContext* ctx, const std::vector<Tensor>& tensors,
int64* shard_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status WriteMetadataFile(Env* env, bool finalized)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
void StopWriterThreads(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
void SignalEOF(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<int64, std::unique_ptr<WriterThread>> writer_threads_
TF_GUARDED_BY(mu_);
absl::flat_hash_map<int64, std::unique_ptr<snapshot_util::AsyncWriter>>
writers_ TF_GUARDED_BY(mu_);
Status writer_status_ TF_GUARDED_BY(mu_);
bool writers_closed_ TF_GUARDED_BY(mu_);
@@ -495,7 +375,8 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
index_(0),
hash_dir_(HashDirectory(dataset()->path_, dataset()->hash_)) {}
hash_dir_(
snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_)) {}
Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
IteratorContext* ctx) {
@@ -557,8 +438,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(
snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
ctx->env(), hash_dir_, &metadata, &file_exists));
if (!file_exists) {
return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
" does not exist any more.");
@@ -573,8 +454,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
} else {
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(
snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
ctx->env(), hash_dir_, &metadata, &file_exists));
// `pending_snapshot_expiry_seconds` is a legacy option where we would not
// write snapshots that we think were still on-going. We decided that this
@@ -620,18 +501,20 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
TF_RETURN_IF_ERROR(
dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));
tstring hash_dir = HashDirectory(dataset()->path_, dataset()->hash_);
auto hash_dir =
snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
bool metadata_file_exists;
experimental::SnapshotMetadataRecord metadata;
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir, &metadata,
&metadata_file_exists));
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
ctx->env(), hash_dir, &metadata, &metadata_file_exists));
auto run_dir = io::JoinPath(hash_dir, metadata.run_id());
auto run_dir = snapshot_util::RunDirectory(hash_dir, metadata.run_id());
std::vector<std::string> snapshot_shard_dirs;
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
io::JoinPath(run_dir,
absl::StrFormat("%s%s", "*", kShardDirectorySuffix)),
io::JoinPath(
run_dir,
absl::StrFormat("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
&snapshot_shard_dirs));
std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());
@@ -691,18 +574,18 @@ SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)
SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
mutex_lock l(mu_);
StopWriterThreads(true);
SignalEOF(true);
}
void SnapshotDatasetV2Op::Dataset::Iterator::Writer::StopWriterThreads(
bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
void SnapshotDatasetV2Op::Dataset::Iterator::Writer::SignalEOF(bool mark_closed)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!writers_closed_) {
// Push the end of sequence signal to each of the threads to close files.
for (auto& writer_thread : writer_threads_) {
writer_thread.second->StopThread();
for (auto& writer : writers_) {
writer.second->SignalEOF();
}
writer_threads_.clear();
writers_.clear();
writers_closed_ = mark_closed;
}
}
@@ -720,10 +603,10 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
metadata.add_dtype(output_dtype);
}
metadata.set_finalized(finalized);
tstring hash_directory = HashDirectory(dataset()->path_, dataset()->hash_);
tstring hash_directory =
snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(hash_directory));
return snapshot_util::WriteMetadataFile(hash_directory, &metadata);
return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
@@ -737,12 +620,13 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
std::vector<Tensor>* tensors, int64* shard_index) {
IteratorContext* ctx, const std::vector<Tensor>& tensors,
int64* shard_index) {
std::vector<Tensor> output_tensors;
// Run the shard function
TF_RETURN_IF_ERROR(
instantiated_shard_func_->RunInstantiated(*tensors, &output_tensors));
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
ctx, tensors, &output_tensors));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {
@@ -758,7 +642,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
*end_of_sequence = false;
WriterThread* current_writer_thread;
snapshot_util::AsyncWriter* current_writer;
{
std::vector<Tensor> output_tensors;
@@ -772,8 +656,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
run_id_ = random::New64();
// Creates the run directory.
run_dir_ = RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_),
run_id_);
run_dir_ = snapshot_util::RunDirectory(
snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
run_id_);
TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
}
@@ -788,27 +673,33 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
// Finalize metadata file when we are at the end of the iterator.
if (*end_of_sequence) {
StopWriterThreads(/*mark_closed=*/true);
SignalEOF(/*mark_closed=*/true);
TF_RETURN_IF_ERROR(writer_status_);
return WriteMetadataFile(ctx->env(), /*finalized=*/true);
}
int64 shard_index = 0;
TF_RETURN_IF_ERROR(GetShardIndex(out_tensors, &shard_index));
TF_RETURN_IF_ERROR(GetShardIndex(ctx, *out_tensors, &shard_index));
// If the index does not exist, we will start a new thread.
if (writer_threads_.count(shard_index) == 0) {
const tstring snapshot_shard_directory =
SnapshotShardDirectory(run_dir_, shard_index);
auto thread_data = std::make_unique<WriterThread>(
this, ctx->env(), shard_index, snapshot_shard_directory,
current_checkpoint_id_);
writer_threads_.insert({shard_index, std::move(thread_data)});
if (writers_.count(shard_index) == 0) {
auto snapshot_shard_directory =
snapshot_util::ShardDirectory(run_dir_, shard_index);
auto writer = std::make_unique<snapshot_util::AsyncWriter>(
ctx->env(), shard_index, snapshot_shard_directory,
current_checkpoint_id_, dataset()->compression_, kFileFormatVersion,
dataset()->output_dtypes(), [this](Status s) {
if (!s.ok()) {
mutex_lock l(mu_);
writer_status_ = s;
}
});
writers_.insert({shard_index, std::move(writer)});
}
current_writer_thread = writer_threads_[shard_index].get();
current_writer = writers_[shard_index].get();
}
current_writer_thread->EnqueueTensors(*out_tensors);
current_writer->Write(*out_tensors);
return Status::OK();
}
@@ -820,11 +711,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kCurrentCheckpointId),
static_cast<int64>(current_checkpoint_id_)));
StopWriterThreads(/*mark_closed=*/false);
writer_threads_.clear();
SignalEOF(/*mark_closed=*/false);
writers_.clear();
current_checkpoint_id_++;
return SaveInput(ctx, writer, input_impl_);
}
@@ -839,8 +728,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
&current_checkpoint_id));
run_id_ = static_cast<uint64>(run_id_signed);
run_dir_ =
RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_), run_id_);
run_dir_ = snapshot_util::RunDirectory(
snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
run_id_);
current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);
return RestoreInput(ctx, reader, input_impl_);
@@ -1056,7 +946,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));
Status dump_status =
snapshot_util::DumpDatasetGraph(path, hash, &graph_def);
snapshot_util::DumpDatasetGraph(ctx->env(), path, hash, &graph_def);
if (!dump_status.ok()) {
LOG(WARNING) << "Unable to write graphdef to disk, error: "
<< dump_status.ToString();
@@ -1244,7 +1134,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
hash_dir_, &metadata, &file_exists));
ctx->env(), hash_dir_, &metadata, &file_exists));
TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
dataset()->mode_, file_exists, &metadata,
dataset()->pending_snapshot_expiry_seconds_, &state_));
@@ -1284,8 +1174,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir_, &metadata,
&file_exists));
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
ctx->env(), hash_dir_, &metadata, &file_exists));
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
VLOG(2) << "Restoring Snapshot iterator: " << state_;
return RestoreInput(ctx, reader, iterator_);
@@ -1600,7 +1490,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
Status ReadFile(Env* env, const string& filename) {
std::unique_ptr<snapshot_util::Reader> reader;
TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
Env::Default(), filename, dataset()->compression_, version_,
env, filename, dataset()->compression_, version_,
dataset()->output_dtypes(), &reader));
while (true) {
// Wait for a slot in the buffer.
@@ -1813,8 +1703,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
// If we're restoring then the directory already exists and we
// don't want to overwrite the snapshot metadata file.
if (!is_restored_) {
TF_RETURN_IF_ERROR(
Env::Default()->RecursivelyCreateDir(run_dir_));
TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
experimental::SnapshotMetadataRecord metadata;
metadata.set_creation_timestamp(EnvTime::NowMicros());
metadata.set_graph_hash(dataset()->graph_hash_);
@@ -1824,8 +1713,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
metadata.add_dtype(output_dtype);
}
metadata.set_finalized(false);
TF_RETURN_IF_ERROR(
snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
ctx->env(), hash_dir_, &metadata));
}
for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
++num_active_threads_;
@@ -2069,11 +1958,6 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
private:
struct BufferElement {
std::vector<Tensor> value;
bool end_of_sequence;
};
string GetSnapshotFilename() {
mutex_lock l(mu_);
string snapshot_data_filename = io::JoinPath(
@@ -2085,7 +1969,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
Status FillBuffer(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
BufferElement elem;
snapshot_util::ElementOrEOF elem;
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, &elem.value, &elem.end_of_sequence));
@@ -2119,16 +2003,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
"maximum size: ", dataset()->writer_buffer_size_);
}
BufferElement elem_copy = next_elem_;
snapshot_util::ElementOrEOF elem_copy = next_elem_;
buffer_.push_back(elem_copy);
cond_var_.notify_all();
return Status::OK();
}
Status ProcessOneElement(int64* bytes_written,
Status ProcessOneElement(Env* env, int64* bytes_written,
string* snapshot_data_filename,
std::unique_ptr<snapshot_util::Writer>* writer,
bool* end_of_processing, Env* env) {
bool* end_of_processing) {
profiler::TraceMe activity(
[&]() {
return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@@ -2138,7 +2022,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
*end_of_processing = false;
bool produced_elem = false;
bool snapshot_failed = false;
BufferElement elem;
snapshot_util::ElementOrEOF elem;
{
mutex_lock l(mu_);
// Wait for buffer to not be empty.
@@ -2175,7 +2059,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
bool should_close;
TF_RETURN_IF_ERROR(
ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
ShouldCloseWriter(env, *snapshot_data_filename, *bytes_written,
(*writer).get(), &should_close));
if (should_close) {
// If we exceed the shard size, we get a new file and reset.
@@ -2198,12 +2082,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
hash_dir_, &metadata, &file_exists));
env, hash_dir_, &metadata, &file_exists));
if (metadata.run_id() == run_id_) {
metadata.set_finalized(true);
TF_RETURN_IF_ERROR(
snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
env, hash_dir_, &metadata));
} else {
// TODO(frankchn): We lost the race, remove all snapshots.
}
@@ -2240,8 +2124,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
bool end_of_processing = false;
while (!end_of_processing) {
Status s =
ProcessOneElement(&bytes_written, &snapshot_data_filename,
&writer, &end_of_processing, env);
ProcessOneElement(env, &bytes_written, &snapshot_data_filename,
&writer, &end_of_processing);
if (!s.ok()) {
LOG(INFO) << "Error while writing snapshot data to disk: "
<< s.ToString();
@@ -2255,7 +2139,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
}
Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
Status ShouldCloseWriter(Env* env, const string& filename,
uint64 bytes_written,
snapshot_util::Writer* writer,
bool* should_close) {
// If the compression ratio has been estimated, use it to decide
@@ -2280,7 +2165,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
// Make sure that all bytes are written out.
TF_RETURN_IF_ERROR(writer->Sync());
uint64 file_size;
TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
mutex_lock l(mu_);
compression_ratio_ = static_cast<double>(bytes_written) /
static_cast<double>(file_size);
@@ -2300,7 +2185,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
// 5. By the background threads when they finish.
condition_variable cond_var_;
BufferElement next_elem_ TF_GUARDED_BY(mu_);
snapshot_util::ElementOrEOF next_elem_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_;
const string hash_dir_;
@@ -2313,7 +2198,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
int64 time_spent_micros_ TF_GUARDED_BY(mu_) = 0;
int64 bytes_produced_ TF_GUARDED_BY(mu_) = 0;
std::deque<BufferElement> buffer_ TF_GUARDED_BY(mu_);
std::deque<snapshot_util::ElementOrEOF> buffer_ TF_GUARDED_BY(mu_);
bool snapshot_failed_ TF_GUARDED_BY(mu_) = false;
bool cancelled_ TF_GUARDED_BY(mu_) = false;
bool first_call_ TF_GUARDED_BY(mu_) = true;


@@ -49,10 +49,27 @@ namespace snapshot_util {
/* static */ constexpr const int64
CustomReader::kSnappyReaderOutputBufferSizeBytes;
std::string GetCurrentCheckpointFile(const std::string& shard_directory,
const uint64 current_checkpoint_id) {
std::string HashDirectory(const std::string& path, uint64 hash) {
return io::JoinPath(path, absl::StrFormat("%d", hash));
}
std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
return RunDirectory(hash_directory, absl::StrFormat("%d", run_id));
}
std::string RunDirectory(const std::string& hash_directory,
const std::string& run_id) {
return io::JoinPath(hash_directory, run_id);
}
std::string ShardDirectory(const std::string& run_directory, int64 shard_id) {
return io::JoinPath(run_directory, absl::StrFormat("%08d%s", shard_id,
kShardDirectorySuffix));
}
std::string GetCheckpointFileName(const std::string& shard_directory,
uint64 checkpoint_id) {
return io::JoinPath(shard_directory,
absl::StrFormat("%08d.snapshot", current_checkpoint_id));
absl::StrFormat("%08d.snapshot", checkpoint_id));
}
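Taken together, these helpers define the snapshot on-disk layout. A hedged sketch of the paths they compose (the hash, run ID, shard, and checkpoint values below are made up for illustration):
// HashDirectory("/user/specified/path", 82930)
//     -> "/user/specified/path/82930"
// RunDirectory("/user/specified/path/82930", 17)
//     -> "/user/specified/path/82930/17"
// ShardDirectory("/user/specified/path/82930/17", 0)
//     -> "/user/specified/path/82930/17/00000000.shard"
// GetCheckpointFileName(".../00000000.shard", 1)
//     -> ".../00000000.shard/00000001.snapshot"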
Status Writer::Create(Env* env, const std::string& filename,
@@ -357,13 +374,6 @@ class Reader::Dataset : public DatasetBase {
}
private:
const std::string shard_dir_;
const std::string compression_;
const int64 version_;
const DataTypeVector dtypes_;
const std::vector<PartialTensorShape> shapes_;
const int64 start_index_;
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
@@ -417,16 +427,9 @@ class Reader::Dataset : public DatasetBase {
}
private:
std::unique_ptr<Reader> reader_;
// Stores the id of the current checkpoint file that we are in the process
// of reading (e.g. if the file is currently 00000001.snapshot, then this
// will be 1).
uint64 current_checkpoint_id_;
std::string GetCurrentFilename() {
return GetCurrentCheckpointFile(dataset()->shard_dir_,
current_checkpoint_id_);
return GetCheckpointFileName(dataset()->shard_dir_,
current_checkpoint_id_);
}
Status AdvanceToNextFile(Env* env) {
@@ -435,7 +438,21 @@ class Reader::Dataset : public DatasetBase {
return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
dataset()->version_, dataset()->dtypes_, &reader_);
}
std::unique_ptr<Reader> reader_;
// Stores the id of the current checkpoint file that we are in the process
// of reading (e.g. if the file is currently 00000001.snapshot, then this
// will be 1).
uint64 current_checkpoint_id_;
};
const std::string shard_dir_;
const std::string compression_;
const int64 version_;
const DataTypeVector dtypes_;
const std::vector<PartialTensorShape> shapes_;
const int64 start_index_;
};
class Reader::NestedDataset : public DatasetBase {
@@ -571,7 +588,7 @@ TFRecordReader::TFRecordReader(const std::string& filename,
dtypes_(dtypes) {}
Status TFRecordReader::Initialize(Env* env) {
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
record_reader_ = absl::make_unique<io::RecordReader>(
file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
@@ -607,7 +624,7 @@ CustomReader::CustomReader(const std::string& filename,
dtypes_(dtypes) {}
Status CustomReader::Initialize(Env* env) {
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
#if defined(IS_SLIM_BUILD)
@@ -807,31 +824,31 @@ Status CustomReader::ReadRecord(absl::Cord* record) {
}
#endif
Status WriteMetadataFile(const string& hash_dir,
Status WriteMetadataFile(Env* env, const string& dir,
const experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
string metadata_filename = io::JoinPath(dir, kMetadataFilename);
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
std::string tmp_filename =
absl::StrCat(metadata_filename, "-tmp-", random::New64());
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, *metadata));
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
return env->RenameFile(tmp_filename, metadata_filename);
}
Status ReadMetadataFile(const string& hash_dir,
Status ReadMetadataFile(Env* env, const string& dir,
experimental::SnapshotMetadataRecord* metadata,
bool* file_exists) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
Status s = Env::Default()->FileExists(metadata_filename);
string metadata_filename = io::JoinPath(dir, kMetadataFilename);
Status s = env->FileExists(metadata_filename);
*file_exists = s.ok();
if (*file_exists) {
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
return ReadBinaryProto(env, metadata_filename, metadata);
} else {
return Status::OK();
}
}
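For illustration, a minimal round-trip sketch over these two helpers (the run ID is hypothetical and the enclosing function is invented for the example):
Status MetadataRoundTrip(Env* env, const string& hash_dir) {
  experimental::SnapshotMetadataRecord metadata;
  metadata.set_run_id("12345");  // hypothetical run ID
  metadata.set_finalized(false);
  // WriteMetadataFile is atomic: it writes a temp file, then renames it.
  TF_RETURN_IF_ERROR(WriteMetadataFile(env, hash_dir, &metadata));
  bool file_exists;
  TF_RETURN_IF_ERROR(ReadMetadataFile(env, hash_dir, &metadata, &file_exists));
  return file_exists ? Status::OK()
                     : errors::NotFound("no metadata in ", hash_dir);
}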
Status DumpDatasetGraph(const std::string& path, uint64 hash,
Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
const GraphDef* graph) {
std::string hash_hex =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
@@ -839,8 +856,8 @@ Status DumpDatasetGraph(const std::string& path, uint64 hash,
io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));
LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(path));
return WriteTextProto(Env::Default(), graph_file, *graph);
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
return WriteTextProto(env, graph_file, *graph);
}
Status DetermineOpState(const std::string& mode_string, bool file_exists,
@@ -893,6 +910,68 @@ Status DetermineOpState(const std::string& mode_string, bool file_exists,
}
}
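As a usage sketch (hedged: `metadata` and `file_exists` are assumed to come from a prior ReadMetadataFile call, and the expiry value is arbitrary):
Mode mode;
TF_RETURN_IF_ERROR(DetermineOpState(
    /*mode_string=*/"auto", file_exists, &metadata,
    /*pending_snapshot_expiry_seconds=*/86400, &mode));
// `mode` is now one of READER, WRITER, or PASSTHROUGH.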
AsyncWriter::AsyncWriter(Env* env, int64 file_index,
const std::string& shard_directory,
uint64 checkpoint_id, const std::string& compression,
int64 version, const DataTypeVector& output_types,
std::function<void(Status)> done) {
thread_ = absl::WrapUnique(env->StartThread(
ThreadOptions(), absl::StrCat("writer_thread_", file_index),
// Capture output_types by value: the constructor's reference parameter
// may not outlive the writer thread.
[this, env, shard_directory, checkpoint_id, compression, version,
output_types, done = std::move(done)] {
done(WriterThread(env, shard_directory, checkpoint_id, compression,
version, output_types));
}));
}
void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
mutex_lock l(mu_);
ElementOrEOF element;
element.value = tensors;
deque_.push_back(std::move(element));
}
void AsyncWriter::SignalEOF() {
mutex_lock l(mu_);
ElementOrEOF be;
be.end_of_sequence = true;
deque_.push_back(std::move(be));
}
void AsyncWriter::Consume(ElementOrEOF* be) {
mutex_lock l(mu_);
mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
*be = deque_.front();
deque_.pop_front();
}
bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }
Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
uint64 checkpoint_id,
const std::string& compression, int64 version,
DataTypeVector output_types) {
std::unique_ptr<snapshot_util::Writer> writer;
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));
TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
version, std::move(output_types), &writer));
while (true) {
ElementOrEOF be;
Consume(&be);
if (be.end_of_sequence) {
TF_RETURN_IF_ERROR(writer->Close());
break;
}
TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
}
return Status::OK();
}
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow


@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@@ -48,11 +49,25 @@ constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";
constexpr char kShardDirectorySuffix[] = ".shard";
enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
std::string GetCurrentCheckpointFile(const std::string& shard_directory,
const uint64 current_checkpoint_id);
// Returns the name of the "hash" directory for the given base path and hash ID.
std::string HashDirectory(const std::string& path, uint64 hash);
// Returns the name of the "run" directory for the given base path and run ID.
std::string RunDirectory(const std::string& hash_directory, uint64 run_id);
std::string RunDirectory(const std::string& hash_directory,
const std::string& run_id);
// Returns the name of the "shard" directory for the given base path and shard
// ID.
std::string ShardDirectory(const std::string& run_directory, int64 shard_id);
// Returns the checkpoint file name for the given directory and checkpoint ID.
std::string GetCheckpointFileName(const std::string& shard_directory,
const uint64 checkpoint_id);
// This is an interface class that exposes snapshot writing functionality.
class Writer {
@@ -265,14 +280,17 @@ class CustomReader : public Reader {
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
};
Status WriteMetadataFile(const string& hash_dir,
// Writes snapshot metadata to the given directory.
Status WriteMetadataFile(Env* env, const string& dir,
const experimental::SnapshotMetadataRecord* metadata);
Status ReadMetadataFile(const string& hash_dir,
// Reads snapshot metadata from the given directory.
Status ReadMetadataFile(Env* env, const string& dir,
experimental::SnapshotMetadataRecord* metadata,
bool* file_exists);
Status DumpDatasetGraph(const std::string& path, uint64 hash,
// Writes a dataset graph to the given directory.
Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
const GraphDef* graph);
Status DetermineOpState(const std::string& mode_string, bool file_exists,
@@ -280,6 +298,59 @@ Status DetermineOpState(const std::string& mode_string, bool file_exists,
const uint64 pending_snapshot_expiry_seconds,
Mode* mode);
// Represents a dataset element or EOF.
struct ElementOrEOF {
std::vector<Tensor> value;
bool end_of_sequence = false;
};
// AsyncWriter provides an API for asynchronously writing dataset elements
// (each represented as a vector of tensors) to a file.
//
// The expected use of this API is:
//
// std::unique_ptr<AsyncWriter> writer = absl::make_unique<AsyncWriter>(...);
//
// while (data_available()) {
// std::vector<Tensor> data = read_data();
// writer->Write(data);
// }
// writer->SignalEOF();
// writer = nullptr; // This will block until writes are flushed.
class AsyncWriter {
public:
explicit AsyncWriter(Env* env, int64 file_index,
const std::string& shard_directory, uint64 checkpoint_id,
const std::string& compression, int64 version,
const DataTypeVector& output_types,
std::function<void(Status)> done);
// Writes the given tensors. The method is non-blocking and returns without
// waiting for the element to be written.
void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_);
// Signals the end of input. The method is non-blocking and returns without
// waiting for the writer to be closed.
void SignalEOF() TF_LOCKS_EXCLUDED(mu_);
private:
void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_);
bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status WriterThread(Env* env, const std::string& shard_directory,
uint64 checkpoint_id, const std::string& compression,
int64 version, DataTypeVector output_types);
mutex mu_;
std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_);
// This has to be last. During destruction, we need to make sure that the
// Thread object is destroyed first as its destructor blocks on thread
// completion. If there are other member variables after this, they may get
// destroyed first before the thread finishes, potentially causing the
// thread to access invalid memory.
std::unique_ptr<Thread> thread_;
};
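Expanding the usage comment above into a fuller, hedged sketch (the shard directory, version, and element are made up; an empty compression string is assumed to mean no compression):
Status writer_status;  // set from the done callback on failure
auto writer = absl::make_unique<AsyncWriter>(
    Env::Default(), /*file_index=*/0,
    /*shard_directory=*/"/tmp/snapshot/82930/17/00000000.shard",
    /*checkpoint_id=*/0, /*compression=*/"", /*version=*/2,
    DataTypeVector{DT_INT64}, [&writer_status](Status s) {
      // Runs on the writer thread; real callers (see the op above) guard
      // this with a mutex.
      if (!s.ok()) writer_status = s;
    });
Tensor element(DT_INT64, TensorShape({}));
element.scalar<int64>()() = 42;
writer->Write({element});  // non-blocking enqueue
writer->SignalEOF();       // non-blocking end-of-input marker
writer = nullptr;          // blocks until the thread flushes and exits;
                           // writer_status is safe to read after this point
Note that the destructor ordering documented on `thread_` is what makes the final read of `writer_status` race-free: destroying the writer joins the thread first.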
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow