[tf.data] Cleanup and refactoring of tf.data snapshot implementation.
PiperOrigin-RevId: 314866920
Change-Id: If68df309d14eee4b697e74242705da93c9b03a99
This commit is contained in:
parent da59266da2
commit 9221044560
@@ -65,8 +65,6 @@ namespace experimental {

// ==== Snapshot Implementation ====

namespace {

/* The current snapshot on-disk layout is as follows:
 * /user/specified/path/
 *   - graphhash1/
@@ -95,25 +93,6 @@ namespace {
 *   ...
 */

constexpr const char* const kShardDirectorySuffix = ".shard";

inline tstring HashDirectory(const tstring& path, const uint64 hash) {
  return io::JoinPath(path, absl::StrFormat("%d", hash));
}

inline tstring RunDirectory(const tstring& hash_directory,
                            const uint64 run_id) {
  return io::JoinPath(hash_directory, absl::StrFormat("%d", run_id));
}

inline tstring SnapshotShardDirectory(const tstring& run_directory,
                                      const int64 snapshot_index) {
  return io::JoinPath(run_directory, absl::StrFormat("%08d%s", snapshot_index,
                                                     kShardDirectorySuffix));
}

}  // namespace
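For illustration (hypothetical hash, run, and shard values, not part of this commit): the helpers above expand the layout documented in the comment into concrete paths like

  /user/specified/path/                user-specified base path
    82387632/                          HashDirectory(path, hash)
      17/                              RunDirectory(hash_dir, run_id)
        00000003.shard/                SnapshotShardDirectory(run_dir, 3)
          00000000.snapshot            checkpoint files ("%08d.snapshot")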

class SnapshotDatasetV2Op::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
@@ -248,118 +227,19 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Writer
                       IteratorStateReader* reader) override;

 private:
  struct BufferElement {
    std::vector<Tensor> value;
    bool end_of_sequence = false;
  };

  class WriterThread {
   public:
    explicit WriterThread(Writer* writer_iterator, Env* env, int64 file_index,
                          const tstring& shard_directory,
                          uint64 current_checkpoint_id) {
      thread_ = absl::WrapUnique(env->StartThread(
          ThreadOptions(), absl::StrCat("snapshot_writer_thread_", file_index),
          [this, writer_iterator, env, shard_directory, current_checkpoint_id] {
            RunWriterThread(writer_iterator, env, shard_directory,
                            current_checkpoint_id);
          }));
    }

    void EnqueueTensors(const std::vector<Tensor>& tensors)
        TF_LOCKS_EXCLUDED(deque_mu_) {
      // Copy the Tensor to the deque for writing.
      mutex_lock l(deque_mu_);
      BufferElement be;
      be.value = tensors;
      deque_.push_back(std::move(be));
    }

    void DequeueTensors(BufferElement* be) TF_LOCKS_EXCLUDED(deque_mu_) {
      mutex_lock l(deque_mu_);
      deque_mu_.Await(
          tensorflow::Condition(this, &WriterThread::DequeIsNotEmpty));
      *be = deque_.front();
      deque_.pop_front();
    }

    void StopThread() TF_LOCKS_EXCLUDED(deque_mu_) {
      mutex_lock l(deque_mu_);
      BufferElement be;
      be.end_of_sequence = true;
      deque_.push_back(std::move(be));
    }

    void RunWriterThread(Writer* writer_iterator, Env* env,
                         const tstring& shard_directory,
                         uint64 current_checkpoint_id) {
      Status s = WriterThreadFn(writer_iterator, env, shard_directory,
                                current_checkpoint_id);
      if (!s.ok()) {
        mutex_lock l(writer_iterator->mu_);
        writer_iterator->writer_status_ = s;
      }
    }

   private:
    bool DequeIsNotEmpty() TF_EXCLUSIVE_LOCKS_REQUIRED(deque_mu_) {
      return !deque_.empty();
    }

    Status WriterThreadFn(Writer* writer_iterator, Env* env,
                          const tstring& shard_directory,
                          uint64 current_checkpoint_id) {
      std::unique_ptr<snapshot_util::Writer> writer;
      TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

      TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
          env,
          snapshot_util::GetCurrentCheckpointFile(shard_directory,
                                                  current_checkpoint_id),
          writer_iterator->dataset()->compression_, kFileFormatVersion,
          writer_iterator->dataset()->output_dtypes(), &writer));

      while (true) {
        BufferElement be;
        DequeueTensors(&be);

        if (be.end_of_sequence) {
          TF_RETURN_IF_ERROR(writer->Close());
          break;
        }

        TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
      }
      return Status::OK();
    }

    // If both the writer `mu_` and this `deque_mu_` need to be acquired, the
    // writer `mu_` must be acquired first.
    mutex deque_mu_;

    std::deque<BufferElement> deque_ TF_GUARDED_BY(deque_mu_);

    // This has to be last. During destruction, we need to make sure that
    // thread_ is destroyed first as the thread destructor blocks on thread
    // completion. If there are other member variables after this, they may get
    // destroyed first before the thread finishes, potentially causing the
    // thread to access invalid memory.
    std::unique_ptr<Thread> thread_;
  };
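For illustration (not part of this commit): the removed WriterThread above is a single-producer/single-consumer queue. EnqueueTensors and StopThread push under deque_mu_, and DequeueTensors blocks via deque_mu_.Await with a tensorflow::Condition predicate until an element or the end-of-sequence sentinel arrives. A minimal standalone sketch of the same shape, using std::condition_variable as a stand-in for the TF primitives (all names here are hypothetical):

#include <condition_variable>
#include <deque>
#include <mutex>
#include <vector>

// Stand-in element type mirroring BufferElement above.
struct Element {
  std::vector<int> value;
  bool end_of_sequence = false;
};

class Queue {
 public:
  void Enqueue(Element e) {
    std::lock_guard<std::mutex> l(mu_);
    deque_.push_back(std::move(e));
    cv_.notify_one();
  }

  // Blocks until an element is available, like deque_mu_.Await(Condition(...)).
  Element Dequeue() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return !deque_.empty(); });
    Element e = std::move(deque_.front());
    deque_.pop_front();
    return e;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<Element> deque_;
};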

  Status GetShardIndex(std::vector<Tensor>* tensors, int64* shard_index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status GetShardIndex(IteratorContext* ctx, const std::vector<Tensor>& tensors,
                       int64* shard_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status WriteMetadataFile(Env* env, bool finalized)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void StopWriterThreads(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  void SignalEOF(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutex mu_;
  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  absl::flat_hash_map<int64, std::unique_ptr<WriterThread>> writer_threads_
      TF_GUARDED_BY(mu_);
  absl::flat_hash_map<int64, std::unique_ptr<snapshot_util::AsyncWriter>>
      writers_ TF_GUARDED_BY(mu_);
  Status writer_status_ TF_GUARDED_BY(mu_);
  bool writers_closed_ TF_GUARDED_BY(mu_);

@@ -495,7 +375,8 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
    : DatasetIterator<Dataset>(params),
      index_(0),
      hash_dir_(HashDirectory(dataset()->path_, dataset()->hash_)) {}
      hash_dir_(
          snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_)) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
    IteratorContext* ctx) {
@@ -557,8 +438,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(

    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(
        snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
    TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
        ctx->env(), hash_dir_, &metadata, &file_exists));
    if (!file_exists) {
      return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                              " does not exist any more.");
@@ -573,8 +454,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
  } else {
    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(
        snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
    TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
        ctx->env(), hash_dir_, &metadata, &file_exists));

    // `pending_snapshot_expiry_seconds` is a legacy option where we would not
    // write snapshots that we think were still on-going. We decided that this
@@ -620,18 +501,20 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
  TF_RETURN_IF_ERROR(
      dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));

  tstring hash_dir = HashDirectory(dataset()->path_, dataset()->hash_);
  auto hash_dir =
      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
  bool metadata_file_exists;
  experimental::SnapshotMetadataRecord metadata;
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir, &metadata,
                                                     &metadata_file_exists));
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
      ctx->env(), hash_dir, &metadata, &metadata_file_exists));

  auto run_dir = io::JoinPath(hash_dir, metadata.run_id());
  auto run_dir = snapshot_util::RunDirectory(hash_dir, metadata.run_id());

  std::vector<std::string> snapshot_shard_dirs;
  TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
      io::JoinPath(run_dir,
                   absl::StrFormat("%s%s", "*", kShardDirectorySuffix)),
      io::JoinPath(
          run_dir,
          absl::StrFormat("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
      &snapshot_shard_dirs));
  std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

@@ -691,18 +574,18 @@ SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)

SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
  mutex_lock l(mu_);
  StopWriterThreads(true);
  SignalEOF(true);
}

void SnapshotDatasetV2Op::Dataset::Iterator::Writer::StopWriterThreads(
    bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
void SnapshotDatasetV2Op::Dataset::Iterator::Writer::SignalEOF(bool mark_closed)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!writers_closed_) {
    // Push the end of sequence signal to each of the threads to close files.
    for (auto& writer_thread : writer_threads_) {
      writer_thread.second->StopThread();
    for (auto& writer : writers_) {
      writer.second->SignalEOF();
    }

    writer_threads_.clear();
    writers_.clear();
    writers_closed_ = mark_closed;
  }
}
@@ -720,10 +603,10 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
    metadata.add_dtype(output_dtype);
  }
  metadata.set_finalized(finalized);
  tstring hash_directory = HashDirectory(dataset()->path_, dataset()->hash_);
  tstring hash_directory =
      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);

  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(hash_directory));
  return snapshot_util::WriteMetadataFile(hash_directory, &metadata);
  return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
@@ -737,12 +620,13 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
    std::vector<Tensor>* tensors, int64* shard_index) {
    IteratorContext* ctx, const std::vector<Tensor>& tensors,
    int64* shard_index) {
  std::vector<Tensor> output_tensors;

  // Run the shard function
  TF_RETURN_IF_ERROR(
      instantiated_shard_func_->RunInstantiated(*tensors, &output_tensors));
  TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
      ctx, tensors, &output_tensors));

  if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
      output_tensors[0].NumElements() != 1) {
@@ -758,7 +642,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  *end_of_sequence = false;
  WriterThread* current_writer_thread;
  snapshot_util::AsyncWriter* current_writer;

  {
    std::vector<Tensor> output_tensors;
@@ -772,8 +656,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
      run_id_ = random::New64();

      // Creates the run directory.
      run_dir_ = RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_),
                              run_id_);
      run_dir_ = snapshot_util::RunDirectory(
          snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
          run_id_);
      TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
      TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
    }
@@ -788,27 +673,33 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(

    // Finalize metadata file when we are at the end of the iterator.
    if (*end_of_sequence) {
      StopWriterThreads(/*mark_closed=*/true);
      SignalEOF(/*mark_closed=*/true);
      TF_RETURN_IF_ERROR(writer_status_);
      return WriteMetadataFile(ctx->env(), /*finalized=*/true);
    }

    int64 shard_index = 0;
    TF_RETURN_IF_ERROR(GetShardIndex(out_tensors, &shard_index));
    TF_RETURN_IF_ERROR(GetShardIndex(ctx, *out_tensors, &shard_index));

    // If the index does not exist, we will start a new thread.
    if (writer_threads_.count(shard_index) == 0) {
      const tstring snapshot_shard_directory =
          SnapshotShardDirectory(run_dir_, shard_index);
      auto thread_data = std::make_unique<WriterThread>(
          this, ctx->env(), shard_index, snapshot_shard_directory,
          current_checkpoint_id_);
      writer_threads_.insert({shard_index, std::move(thread_data)});
    if (writers_.count(shard_index) == 0) {
      auto snapshot_shard_directory =
          snapshot_util::ShardDirectory(run_dir_, shard_index);
      auto writer = std::make_unique<snapshot_util::AsyncWriter>(
          ctx->env(), shard_index, snapshot_shard_directory,
          current_checkpoint_id_, dataset()->compression_, kFileFormatVersion,
          dataset()->output_dtypes(), [this](Status s) {
            if (!s.ok()) {
              mutex_lock l(mu_);
              writer_status_ = s;
            }
          });
      writers_.insert({shard_index, std::move(writer)});
    }
    current_writer_thread = writer_threads_[shard_index].get();
    current_writer = writers_[shard_index].get();
  }

  current_writer_thread->EnqueueTensors(*out_tensors);
  current_writer->Write(*out_tensors);
  return Status::OK();
}

@@ -820,11 +711,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCurrentCheckpointId),
                          static_cast<int64>(current_checkpoint_id_)));

  StopWriterThreads(/*mark_closed=*/false);
  writer_threads_.clear();
  SignalEOF(/*mark_closed=*/false);
  writers_.clear();
  current_checkpoint_id_++;

  return SaveInput(ctx, writer, input_impl_);
}

@@ -839,8 +728,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
                                        &current_checkpoint_id));

  run_id_ = static_cast<uint64>(run_id_signed);
  run_dir_ =
      RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_), run_id_);
  run_dir_ = snapshot_util::RunDirectory(
      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
      run_id_);
  current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);

  return RestoreInput(ctx, reader, input_impl_);
@@ -1056,7 +946,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
    OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));

    Status dump_status =
        snapshot_util::DumpDatasetGraph(path, hash, &graph_def);
        snapshot_util::DumpDatasetGraph(ctx->env(), path, hash, &graph_def);
    if (!dump_status.ok()) {
      LOG(WARNING) << "Unable to write graphdef to disk, error: "
                   << dump_status.ToString();
@@ -1244,7 +1134,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      experimental::SnapshotMetadataRecord metadata;
      bool file_exists;
      TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
          hash_dir_, &metadata, &file_exists));
          ctx->env(), hash_dir_, &metadata, &file_exists));
      TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
          dataset()->mode_, file_exists, &metadata,
          dataset()->pending_snapshot_expiry_seconds_, &state_));
@@ -1284,8 +1174,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      }
      experimental::SnapshotMetadataRecord metadata;
      bool file_exists;
      TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir_, &metadata,
                                                         &file_exists));
      TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
          ctx->env(), hash_dir_, &metadata, &file_exists));
      TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
      VLOG(2) << "Restoring Snapshot iterator: " << state_;
      return RestoreInput(ctx, reader, iterator_);
@@ -1600,7 +1490,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      Status ReadFile(Env* env, const string& filename) {
        std::unique_ptr<snapshot_util::Reader> reader;
        TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
            Env::Default(), filename, dataset()->compression_, version_,
            env, filename, dataset()->compression_, version_,
            dataset()->output_dtypes(), &reader));
        while (true) {
          // Wait for a slot in the buffer.
@@ -1813,8 +1703,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
          // If we're restoring then the directory already exists and we
          // don't want to overwrite the snapshot metadata file.
          if (!is_restored_) {
            TF_RETURN_IF_ERROR(
                Env::Default()->RecursivelyCreateDir(run_dir_));
            TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
            experimental::SnapshotMetadataRecord metadata;
            metadata.set_creation_timestamp(EnvTime::NowMicros());
            metadata.set_graph_hash(dataset()->graph_hash_);
@@ -1824,8 +1713,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
              metadata.add_dtype(output_dtype);
            }
            metadata.set_finalized(false);
            TF_RETURN_IF_ERROR(
                snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
            TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
                ctx->env(), hash_dir_, &metadata));
          }
          for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
            ++num_active_threads_;
@@ -2069,11 +1958,6 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      }

     private:
      struct BufferElement {
        std::vector<Tensor> value;
        bool end_of_sequence;
      };

      string GetSnapshotFilename() {
        mutex_lock l(mu_);
        string snapshot_data_filename = io::JoinPath(
@@ -2085,7 +1969,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      }

      Status FillBuffer(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
        BufferElement elem;
        snapshot_util::ElementOrEOF elem;
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, &elem.value, &elem.end_of_sequence));

@@ -2119,16 +2003,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
              "maximum size: ", dataset()->writer_buffer_size_);
        }

        BufferElement elem_copy = next_elem_;
        snapshot_util::ElementOrEOF elem_copy = next_elem_;
        buffer_.push_back(elem_copy);
        cond_var_.notify_all();
        return Status::OK();
      }

      Status ProcessOneElement(int64* bytes_written,
      Status ProcessOneElement(Env* env, int64* bytes_written,
                               string* snapshot_data_filename,
                               std::unique_ptr<snapshot_util::Writer>* writer,
                               bool* end_of_processing, Env* env) {
                               bool* end_of_processing) {
        profiler::TraceMe activity(
            [&]() {
              return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@@ -2138,7 +2022,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        *end_of_processing = false;
        bool produced_elem = false;
        bool snapshot_failed = false;
        BufferElement elem;
        snapshot_util::ElementOrEOF elem;
        {
          mutex_lock l(mu_);
          // Wait for buffer to not be empty.
@@ -2175,7 +2059,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

          bool should_close;
          TF_RETURN_IF_ERROR(
              ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
              ShouldCloseWriter(env, *snapshot_data_filename, *bytes_written,
                                (*writer).get(), &should_close));
          if (should_close) {
            // If we exceed the shard size, we get a new file and reset.
@@ -2198,12 +2082,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
            experimental::SnapshotMetadataRecord metadata;
            bool file_exists;
            TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
                hash_dir_, &metadata, &file_exists));
                env, hash_dir_, &metadata, &file_exists));

            if (metadata.run_id() == run_id_) {
              metadata.set_finalized(true);
              TF_RETURN_IF_ERROR(
                  snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
              TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
                  env, hash_dir_, &metadata));
            } else {
              // TODO(frankchn): We lost the race, remove all snapshots.
            }
@@ -2240,8 +2124,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        bool end_of_processing = false;
        while (!end_of_processing) {
          Status s =
              ProcessOneElement(&bytes_written, &snapshot_data_filename,
                                &writer, &end_of_processing, env);
              ProcessOneElement(env, &bytes_written, &snapshot_data_filename,
                                &writer, &end_of_processing);
          if (!s.ok()) {
            LOG(INFO) << "Error while writing snapshot data to disk: "
                      << s.ToString();
@@ -2255,7 +2139,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        }
      }

      Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
      Status ShouldCloseWriter(Env* env, const string& filename,
                               uint64 bytes_written,
                               snapshot_util::Writer* writer,
                               bool* should_close) {
        // If the compression ratio has been estimated, use it to decide
@@ -2280,7 +2165,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        // Make sure that all bytes are written out.
        TF_RETURN_IF_ERROR(writer->Sync());
        uint64 file_size;
        TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
        TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
        mutex_lock l(mu_);
        compression_ratio_ = static_cast<double>(bytes_written) /
                             static_cast<double>(file_size);
@@ -2300,7 +2185,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        // 5. By the background threads when they finish.
        condition_variable cond_var_;

        BufferElement next_elem_ TF_GUARDED_BY(mu_);
        snapshot_util::ElementOrEOF next_elem_ TF_GUARDED_BY(mu_);
        std::unique_ptr<IteratorBase> input_impl_;

        const string hash_dir_;
@@ -2313,7 +2198,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        int64 time_spent_micros_ TF_GUARDED_BY(mu_) = 0;
        int64 bytes_produced_ TF_GUARDED_BY(mu_) = 0;

        std::deque<BufferElement> buffer_ TF_GUARDED_BY(mu_);
        std::deque<snapshot_util::ElementOrEOF> buffer_ TF_GUARDED_BY(mu_);
        bool snapshot_failed_ TF_GUARDED_BY(mu_) = false;
        bool cancelled_ TF_GUARDED_BY(mu_) = false;
        bool first_call_ TF_GUARDED_BY(mu_) = true;

@@ -49,10 +49,27 @@ namespace snapshot_util {
/* static */ constexpr const int64
    CustomReader::kSnappyReaderOutputBufferSizeBytes;

std::string GetCurrentCheckpointFile(const std::string& shard_directory,
                                     const uint64 current_checkpoint_id) {
std::string HashDirectory(const std::string& path, uint64 hash) {
  return io::JoinPath(path, absl::StrFormat("%d", hash));
}

std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
  return RunDirectory(hash_directory, absl::StrFormat("%d", run_id));
}

std::string RunDirectory(const std::string& hash_directory,
                         const std::string& run_id) {
  return io::JoinPath(hash_directory, run_id);
}

std::string ShardDirectory(const std::string& run_directory, int64 shard_id) {
  return io::JoinPath(run_directory, absl::StrFormat("%08d%s", shard_id,
                                                     kShardDirectorySuffix));
}
std::string GetCheckpointFileName(const std::string& shard_directory,
                                  uint64 checkpoint_id) {
  return io::JoinPath(shard_directory,
                      absl::StrFormat("%08d.snapshot", current_checkpoint_id));
                      absl::StrFormat("%08d.snapshot", checkpoint_id));
}
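For illustration (not part of this commit): the format strings above fully determine the on-disk names. A self-contained sketch, using plain snprintf/std::string as stand-ins for absl::StrFormat and io::JoinPath and made-up input values, that checks the composed path:

#include <cassert>
#include <cstdio>
#include <string>

// Stand-in for io::JoinPath on a POSIX-style filesystem.
static std::string JoinPath(const std::string& a, const std::string& b) {
  return a + "/" + b;
}

// Stand-in for absl::StrFormat with a single integer argument.
static std::string Format(const char* fmt, long long v) {
  char buf[32];
  std::snprintf(buf, sizeof(buf), fmt, v);
  return buf;
}

int main() {
  std::string hash_dir = JoinPath("/tmp/snap", Format("%lld", 123LL));
  std::string run_dir = JoinPath(hash_dir, Format("%lld", 17LL));
  std::string shard_dir = JoinPath(run_dir, Format("%08lld", 3LL) + ".shard");
  std::string ckpt = JoinPath(shard_dir, Format("%08lld", 0LL) + ".snapshot");
  assert(ckpt == "/tmp/snap/123/17/00000003.shard/00000000.snapshot");
  return 0;
}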

Status Writer::Create(Env* env, const std::string& filename,
@@ -357,13 +374,6 @@ class Reader::Dataset : public DatasetBase {
  }

 private:
  const std::string shard_dir_;
  const std::string compression_;
  const int64 version_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
  const int64 start_index_;

  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
@@ -417,16 +427,9 @@ class Reader::Dataset : public DatasetBase {
    }

   private:
    std::unique_ptr<Reader> reader_;

    // Stores the id of the current checkpoint file that we are in the process
    // of reading (e.g. if the file is currently 00000001.snapshot, then this
    // will be 1).
    uint64 current_checkpoint_id_;

    std::string GetCurrentFilename() {
      return GetCurrentCheckpointFile(dataset()->shard_dir_,
                                      current_checkpoint_id_);
      return GetCheckpointFileName(dataset()->shard_dir_,
                                   current_checkpoint_id_);
    }

    Status AdvanceToNextFile(Env* env) {
@@ -435,7 +438,21 @@ class Reader::Dataset : public DatasetBase {
      return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
                            dataset()->version_, dataset()->dtypes_, &reader_);
    }

    std::unique_ptr<Reader> reader_;

    // Stores the id of the current checkpoint file that we are in the process
    // of reading (e.g. if the file is currently 00000001.snapshot, then this
    // will be 1).
    uint64 current_checkpoint_id_;
  };

  const std::string shard_dir_;
  const std::string compression_;
  const int64 version_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
  const int64 start_index_;
};

class Reader::NestedDataset : public DatasetBase {
@@ -571,7 +588,7 @@ TFRecordReader::TFRecordReader(const std::string& filename,
      dtypes_(dtypes) {}

Status TFRecordReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));

  record_reader_ = absl::make_unique<io::RecordReader>(
      file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
@@ -607,7 +624,7 @@ CustomReader::CustomReader(const std::string& filename,
      dtypes_(dtypes) {}

Status CustomReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename_, &file_));
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
  input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());

#if defined(IS_SLIM_BUILD)
@@ -807,31 +824,31 @@ Status CustomReader::ReadRecord(absl::Cord* record) {
}
#endif

Status WriteMetadataFile(const string& hash_dir,
Status WriteMetadataFile(Env* env, const string& dir,
                         const experimental::SnapshotMetadataRecord* metadata) {
  string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
  TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
  std::string tmp_filename =
      absl::StrCat(metadata_filename, "-tmp-", random::New64());
  TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, *metadata));
  return Env::Default()->RenameFile(tmp_filename, metadata_filename);
  TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
  return env->RenameFile(tmp_filename, metadata_filename);
}
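WriteMetadataFile above uses the classic write-to-a-temporary-then-rename pattern: the proto is written to a uniquely suffixed temp file and then renamed over the final name, so a concurrent reader sees either the old or the new metadata, never a torn write (rename is atomic on POSIX filesystems; object stores may behave differently). A minimal standalone sketch of the pattern, with hypothetical names (not part of this commit):

#include <cstdio>   // std::fopen, std::fwrite, std::rename
#include <string>

// Writes `data` to `path` via a temp file in the same directory, then
// renames it over the destination so readers never see partial contents.
bool AtomicWrite(const std::string& path, const std::string& data) {
  const std::string tmp = path + ".tmp-12345";  // use a unique suffix in real code
  std::FILE* f = std::fopen(tmp.c_str(), "wb");
  if (f == nullptr) return false;
  const bool ok =
      std::fwrite(data.data(), 1, data.size(), f) == data.size() &&
      std::fflush(f) == 0;
  std::fclose(f);
  if (!ok) return false;
  return std::rename(tmp.c_str(), path.c_str()) == 0;  // atomic on POSIX
}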

Status ReadMetadataFile(const string& hash_dir,
Status ReadMetadataFile(Env* env, const string& dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists) {
  string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
  Status s = Env::Default()->FileExists(metadata_filename);
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  Status s = env->FileExists(metadata_filename);
  *file_exists = s.ok();

  if (*file_exists) {
    return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
    return ReadBinaryProto(env, metadata_filename, metadata);
  } else {
    return Status::OK();
  }
}

Status DumpDatasetGraph(const std::string& path, uint64 hash,
Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
                        const GraphDef* graph) {
  std::string hash_hex =
      strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
@@ -839,8 +856,8 @@ Status DumpDatasetGraph(const std::string& path, uint64 hash,
      io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));

  LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
  TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(path));
  return WriteTextProto(Env::Default(), graph_file, *graph);
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
  return WriteTextProto(env, graph_file, *graph);
}

Status DetermineOpState(const std::string& mode_string, bool file_exists,
@@ -893,6 +910,68 @@ Status DetermineOpState(const std::string& mode_string, bool file_exists,
  }
}

AsyncWriter::AsyncWriter(Env* env, int64 file_index,
                         const std::string& shard_directory,
                         uint64 checkpoint_id, const std::string& compression,
                         int64 version, const DataTypeVector& output_types,
                         std::function<void(Status)> done) {
  thread_ = absl::WrapUnique(env->StartThread(
      ThreadOptions(), absl::StrCat("writer_thread_", file_index),
      [this, env, shard_directory, checkpoint_id, compression, version,
       &output_types, done = std::move(done)] {
        done(WriterThread(env, shard_directory, checkpoint_id, compression,
                          version, output_types));
      }));
}

void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
  mutex_lock l(mu_);
  ElementOrEOF element;
  element.value = tensors;
  deque_.push_back(std::move(element));
}

void AsyncWriter::SignalEOF() {
  mutex_lock l(mu_);
  ElementOrEOF be;
  be.end_of_sequence = true;
  deque_.push_back(std::move(be));
}

void AsyncWriter::Consume(ElementOrEOF* be) {
  mutex_lock l(mu_);
  mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
  *be = deque_.front();
  deque_.pop_front();
}

bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }

Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
                                 uint64 checkpoint_id,
                                 const std::string& compression, int64 version,
                                 DataTypeVector output_types) {
  std::unique_ptr<snapshot_util::Writer> writer;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

  TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
      env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
      version, std::move(output_types), &writer));

  while (true) {
    ElementOrEOF be;
    Consume(&be);

    if (be.end_of_sequence) {
      TF_RETURN_IF_ERROR(writer->Close());
      break;
    }

    TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
  }
  return Status::OK();
}
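Note how AsyncWriter reports errors: the constructor's `done` callback receives the terminal Status of WriterThread and runs on the writer thread itself, so the call site (see the lambda passed in GetNextInternal above) must take its own lock before recording the status. A standalone sketch of that callback shape (the Status type and all names here are stand-ins, not the TF types; not part of this commit):

#include <functional>
#include <mutex>
#include <string>
#include <utility>

// Stand-in for tensorflow::Status.
struct Status {
  bool ok = true;
  std::string message;
};

class Caller {
 public:
  // The callback runs on the worker thread, so it does its own locking.
  std::function<void(Status)> MakeDoneCallback() {
    return [this](Status s) {
      if (!s.ok) {
        std::lock_guard<std::mutex> l(mu_);
        status_ = std::move(s);  // surfaced on the caller's next operation
      }
    };
  }

 private:
  std::mutex mu_;
  Status status_;
};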

}  // namespace snapshot_util
}  // namespace data
}  // namespace tensorflow

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
@@ -48,11 +49,25 @@ constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";
constexpr char kShardDirectorySuffix[] = ".shard";

enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };

std::string GetCurrentCheckpointFile(const std::string& shard_directory,
                                     const uint64 current_checkpoint_id);
// Returns the name of the "hash" directory for the given base path and hash
// ID.
std::string HashDirectory(const std::string& path, uint64 hash);

// Returns the name of the "run" directory for the given base path and run ID.
std::string RunDirectory(const std::string& hash_directory, uint64 run_id);
std::string RunDirectory(const std::string& hash_directory,
                         const std::string& run_id);

// Returns the name of the "shard" directory for the given base path and shard
// ID.
std::string ShardDirectory(const std::string& run_directory, int64 shard_id);

// Returns the checkpoint file name for the given directory and checkpoint ID.
std::string GetCheckpointFileName(const std::string& shard_directory,
                                  const uint64 checkpoint_id);

// This is an interface class that exposes snapshot writing functionality.
class Writer {
@@ -265,14 +280,17 @@ class CustomReader : public Reader {
  std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
};

Status WriteMetadataFile(const string& hash_dir,
// Writes snapshot metadata to the given directory.
Status WriteMetadataFile(Env* env, const string& dir,
                         const experimental::SnapshotMetadataRecord* metadata);

Status ReadMetadataFile(const string& hash_dir,
// Reads snapshot metadata from the given directory.
Status ReadMetadataFile(Env* env, const string& dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists);

Status DumpDatasetGraph(const std::string& path, uint64 hash,
// Writes a dataset graph to the given directory.
Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
                        const GraphDef* graph);

Status DetermineOpState(const std::string& mode_string, bool file_exists,
@@ -280,6 +298,59 @@ Status DetermineOpState(const std::string& mode_string, bool file_exists,
                        const uint64 pending_snapshot_expiry_seconds,
                        Mode* mode);

// Represents a dataset element or EOF.
struct ElementOrEOF {
  std::vector<Tensor> value;
  bool end_of_sequence = false;
};

// AsyncWriter provides an API for asynchronously writing dataset elements
// (each represented as a vector of tensors) to a file.
//
// The expected use of this API is:
//
// std::unique_ptr<AsyncWriter> writer = absl::make_unique<AsyncWriter>(...);
//
// while (data_available()) {
//   std::vector<Tensor> data = read_data();
//   writer->Write(data);
// }
// writer->SignalEOF();
// writer = nullptr;  // This will block until writes are flushed.
class AsyncWriter {
 public:
  explicit AsyncWriter(Env* env, int64 file_index,
                       const std::string& shard_directory, uint64 checkpoint_id,
                       const std::string& compression, int64 version,
                       const DataTypeVector& output_types,
                       std::function<void(Status)> done);

  // Writes the given tensors. The method is non-blocking and returns without
  // waiting for the element to be written.
  void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_);

  // Signals the end of input. The method is non-blocking and returns without
  // waiting for the writer to be closed.
  void SignalEOF() TF_LOCKS_EXCLUDED(mu_);

 private:
  void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_);
  bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status WriterThread(Env* env, const std::string& shard_directory,
                      uint64 checkpoint_id, const std::string& compression,
                      int64 version, DataTypeVector output_types);

  mutex mu_;
  std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_);

  // This has to be last. During destruction, we need to make sure that the
  // Thread object is destroyed first as its destructor blocks on thread
  // completion. If there are other member variables after this, they may get
  // destroyed first before the thread finishes, potentially causing the
  // thread to access invalid memory.
  std::unique_ptr<Thread> thread_;
};
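The ordering comment on thread_ relies on C++ destroying non-static members in reverse declaration order: the Thread (whose destructor joins) goes first, so the state the thread reads still exists while it winds down. A standalone illustration with a Thread-like wrapper that joins in its destructor (all names hypothetical, not part of this commit):

#include <memory>
#include <thread>
#include <utility>

// Stand-in for tensorflow::Thread: the destructor blocks until completion.
class JoiningThread {
 public:
  template <typename F>
  explicit JoiningThread(F&& f) : t_(std::forward<F>(f)) {}
  ~JoiningThread() {
    if (t_.joinable()) t_.join();
  }

 private:
  std::thread t_;
};

class Owner {
 public:
  Owner()
      : thread_(std::make_unique<JoiningThread>([this] {
          long total = 0;
          for (int v : data_) total += v;  // reads a sibling member
          (void)total;
        })) {}

 private:
  int data_[4] = {1, 2, 3, 4};
  // Declared last: destroyed (and therefore joined) first, so the lambda can
  // never observe a destroyed data_. Swapping the declarations would risk a
  // use-after-free, exactly the hazard the comment above describes.
  std::unique_ptr<JoiningThread> thread_;
};

int main() { Owner o; }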

}  // namespace snapshot_util
}  // namespace data
}  // namespace tensorflow