Add implementation for snapshot dataset v2.

PiperOrigin-RevId: 314258033
Change-Id: I6151fdc646a297090de6eeeb3254a556ae9d13bc
Frank Chen 2020-06-01 20:34:49 -07:00 committed by TensorFlower Gardener
parent 70387ab55b
commit a3f393bc95
15 changed files with 1592 additions and 7 deletions

View File

@@ -0,0 +1,41 @@
op {
graph_op_name: "SnapshotDatasetV2"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the input dataset.
END
}
in_arg {
name: "path"
description: <<END
The path we should write snapshots to / read snapshots from.
END
}
attr {
name: "compression"
description: <<END
The type of compression to be applied to the saved snapshot files.
END
}
attr {
name: "reader_func"
description: <<END
Optional. A function to control how to read data from snapshot shards.
END
}
attr {
name: "shard_func"
description: <<END
Optional. A function to control how to shard data when writing a snapshot.
END
}
summary: "Creates a dataset that will write to / read from a snapshot."
description: <<END
This dataset attempts to determine whether a valid snapshot exists at
`path`, and if so, reads from the snapshot in lieu of using `input_dataset`.
If a valid snapshot does not exist, it runs the preprocessing pipeline as
usual and writes out a snapshot of the processed data for future use.
END
}
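
For reference, the behavior described above is exposed through the Python `tf.data.experimental.snapshot` transformation added later in this change. A minimal usage sketch (the directory and the preprocessing step are illustrative, not part of this change):

```python
import tensorflow as tf

snapshot_dir = "/tmp/my_snapshot"  # hypothetical; any writable directory works

dataset = tf.data.Dataset.range(1000)
dataset = dataset.map(lambda x: x * 2)  # stand-in for expensive preprocessing
dataset = dataset.apply(tf.data.experimental.snapshot(snapshot_dir))

# First run: no valid snapshot exists yet, so the pipeline above runs as usual
# and its output is written under `snapshot_dir`. Later runs with the same
# pipeline graph read from the snapshot instead of re-running preprocessing.
for _ in dataset:
    pass
```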

View File

@@ -544,6 +544,7 @@ cc_library(
tf_kernel_library(
name = "snapshot_dataset_op",
srcs = ["snapshot_dataset_op.cc"],
hdrs = ["snapshot_dataset_op.h"],
deps = [
":snapshot_util",
"//tensorflow/core:core_cpu_internal",
@@ -552,10 +553,16 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:op_requires",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/kernels/data:captured_function",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/platform:platform_port",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
],
)

View File

@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h"
#include <random>
#include "absl/time/clock.h"
@@ -60,6 +62,875 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace experimental {
// ==== Snapshot Implementation ====
namespace {
/* The current snapshot on-disk layout is as follows:
* /user/specified/path/
* - graphhash1/
* - snapshot.metadata // metadata file
* - run1/
* - 00000000.shard/ // shard index
* // new checkpoint files are created on all threads at once, either
* // when a file gets too big, or when a TF checkpoint happens.
* - 00000000.snapshot // checkpoint file 0
* - 00000001.snapshot // checkpoint file 1
* - ...
* - 00000001.shard/
* - 00000000.snapshot
* - 00000001.snapshot
* - ...
* - 00000002.shard/
* - 00000000.snapshot
* - 00000001.snapshot
* - ...
* ...
* - run2/
* ...
* - graphhash2/
* ...
* - graphhash3/
* ...
*/
constexpr const char* const kShardDirectorySuffix = ".shard";
inline tstring HashDirectory(const tstring& path, const uint64 hash) {
return io::JoinPath(path, absl::StrFormat("%d", hash));
}
inline tstring RunDirectory(const tstring& hash_directory,
const uint64 run_id) {
return io::JoinPath(hash_directory, absl::StrFormat("%d", run_id));
}
inline tstring SnapshotShardDirectory(const tstring& run_directory,
const int64 snapshot_index) {
return io::JoinPath(run_directory, absl::StrFormat("%08d%s", snapshot_index,
kShardDirectorySuffix));
}
} // namespace
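To make the layout described above concrete, here is a small Python sketch (illustration only, not part of this change) that mirrors the `HashDirectory`, `RunDirectory`, and `SnapshotShardDirectory` helpers; the graph hash and run ID values are made up:

```python
import os

SHARD_DIRECTORY_SUFFIX = ".shard"

def hash_directory(path, graph_hash):
    # /user/specified/path/<graph_hash>
    return os.path.join(path, "%d" % graph_hash)

def run_directory(hash_dir, run_id):
    # /user/specified/path/<graph_hash>/<run_id>
    return os.path.join(hash_dir, "%d" % run_id)

def snapshot_shard_directory(run_dir, shard_index):
    # /user/specified/path/<graph_hash>/<run_id>/00000000.shard
    return os.path.join(run_dir, "%08d%s" % (shard_index, SHARD_DIRECTORY_SUFFIX))

# Example with made-up identifiers:
print(snapshot_shard_directory(
    run_directory(hash_directory("/user/specified/path", 1234567890), 42), 0))
# -> /user/specified/path/1234567890/42/00000000.shard
```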
class SnapshotDatasetV2Op::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
const std::string& path, const std::string& compression,
std::unique_ptr<CapturedFunction> reader_func,
std::unique_ptr<CapturedFunction> shard_func);
~Dataset() override;
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override;
const DataTypeVector& output_dtypes() const override;
const std::vector<PartialTensorShape>& output_shapes() const override;
string DebugString() const override;
int64 Cardinality() const override;
Status CheckExternalState() const override;
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override;
private:
const DatasetBase* input_;
const uint64 hash_;
const tstring path_;
const std::string compression_;
std::unique_ptr<CapturedFunction> reader_func_;
std::unique_ptr<CapturedFunction> shard_func_;
class Iterator;
};
class SnapshotDatasetV2Op::Dataset::Iterator : public DatasetIterator<Dataset> {
public:
static constexpr const char* const kIteratorMode = "iterator_mode";
static constexpr const char* const kIndex = "index";
static constexpr const char* const kGraphHashDirectory =
"graph_hash_directory";
explicit Iterator(const Params& params);
Status Initialize(IteratorContext* ctx) override;
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override;
protected:
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override;
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override;
private:
Status InitializeIterator(IteratorContext* ctx, IteratorStateReader* reader);
int64 index_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
snapshot_util::Mode mode_ TF_GUARDED_BY(mu_);
const std::string hash_dir_;
mutex mu_;
class Reader;
class Writer;
class Passthrough;
};
class SnapshotDatasetV2Op::Dataset::Iterator::Reader
: public DatasetIterator<Dataset> {
public:
static constexpr const char* const kIteratorName = "Reader";
explicit Reader(const Params& params, int64 start_index);
~Reader() override;
Status Initialize(IteratorContext* ctx) override;
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override;
protected:
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override;
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override;
private:
const int64 start_index_;
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_);
std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
TF_GUARDED_BY(mu_);
};
class SnapshotDatasetV2Op::Dataset::Iterator::Writer
: public DatasetIterator<Dataset> {
public:
static constexpr const char* const kIteratorName = "Writer";
static constexpr const char* const kRunId = "run_id";
static constexpr const char* const kCurrentCheckpointId =
"current_checkpoint_id";
explicit Writer(const Params& params);
~Writer() override;
Status Initialize(IteratorContext* ctx) override;
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override;
protected:
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override;
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override;
private:
struct BufferElement {
std::vector<Tensor> value;
bool end_of_sequence = false;
};
class WriterThread {
public:
explicit WriterThread(Writer* writer_iterator, Env* env, int64 file_index,
const tstring& shard_directory,
uint64 current_checkpoint_id) {
thread_ = absl::WrapUnique(env->StartThread(
ThreadOptions(), absl::StrCat("snapshot_writer_thread_", file_index),
[this, writer_iterator, env, shard_directory, current_checkpoint_id] {
RunWriterThread(writer_iterator, env, shard_directory,
current_checkpoint_id);
}));
}
void EnqueueTensors(const std::vector<Tensor>& tensors)
TF_LOCKS_EXCLUDED(deque_mu_) {
// Copy the tensors into the deque for the writer thread to consume.
mutex_lock l(deque_mu_);
BufferElement be;
be.value = tensors;
deque_.push_back(std::move(be));
}
void DequeueTensors(BufferElement* be) TF_LOCKS_EXCLUDED(deque_mu_) {
mutex_lock l(deque_mu_);
deque_mu_.Await(
tensorflow::Condition(this, &WriterThread::DequeIsNotEmpty));
*be = deque_.front();
deque_.pop_front();
}
void StopThread() TF_LOCKS_EXCLUDED(deque_mu_) {
mutex_lock l(deque_mu_);
BufferElement be;
be.end_of_sequence = true;
deque_.push_back(std::move(be));
}
void RunWriterThread(Writer* writer_iterator, Env* env,
const tstring& shard_directory,
uint64 current_checkpoint_id) {
Status s = WriterThreadFn(writer_iterator, env, shard_directory,
current_checkpoint_id);
if (!s.ok()) {
mutex_lock l(writer_iterator->mu_);
writer_iterator->writer_status_ = s;
}
}
private:
bool DequeIsNotEmpty() TF_EXCLUSIVE_LOCKS_REQUIRED(deque_mu_) {
return !deque_.empty();
}
Status WriterThreadFn(Writer* writer_iterator, Env* env,
const tstring& shard_directory,
uint64 current_checkpoint_id) {
std::unique_ptr<snapshot_util::Writer> writer;
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));
TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
env,
snapshot_util::GetCurrentCheckpointFile(shard_directory,
current_checkpoint_id),
writer_iterator->dataset()->compression_, kSnapshotFileFormatVersion,
writer_iterator->dataset()->output_dtypes(), &writer));
while (true) {
BufferElement be;
DequeueTensors(&be);
if (be.end_of_sequence) {
TF_RETURN_IF_ERROR(writer->Close());
break;
}
TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
}
return Status::OK();
}
// If both the writer `mu_` and this `deque_mu_` need to be acquired, the
// writer `mu_` must be acquired first.
mutex deque_mu_;
std::deque<BufferElement> deque_ TF_GUARDED_BY(deque_mu_);
// This has to be last. During destruction, we need to make sure that
// thread_ is destroyed first as the thread destructor blocks on thread
// completion. If there are other member variables after this, they may get
// destroyed first before the thread finishes, potentially causing the
// thread to access invalid memory.
std::unique_ptr<Thread> thread_;
};
Status GetShardIndex(std::vector<Tensor>* tensors, int64* shard_index)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status WriteMetadataFile(Env* env, bool finalized)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
void StopWriterThreads(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<int64, std::unique_ptr<WriterThread>> writer_threads_
TF_GUARDED_BY(mu_);
Status writer_status_ TF_GUARDED_BY(mu_);
bool writers_closed_ TF_GUARDED_BY(mu_);
uint64 run_id_ TF_GUARDED_BY(mu_);
tstring run_dir_ TF_GUARDED_BY(mu_);
// Stores the ID of the current checkpoint .snapshot file being written. See
// top of this file for the directory layout.
uint64 current_checkpoint_id_ TF_GUARDED_BY(mu_);
std::unique_ptr<InstantiatedCapturedFunction> instantiated_shard_func_
TF_GUARDED_BY(mu_);
};
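The `Writer` iterator above fans elements out to one background thread per shard through a mutex-guarded deque, shutting each thread down by enqueuing an end-of-sequence element. A Python sketch of that producer/consumer pattern (an illustration of the idea only; it does not use the TensorFlow threading or snapshot APIs):

```python
import collections
import threading

class ShardWriterThread:
    """Consumes enqueued elements and writes them to a single shard (sketch)."""

    _SENTINEL = object()  # plays the role of `end_of_sequence`

    def __init__(self, write_fn):
        self._write_fn = write_fn         # e.g. appends one record to a shard file
        self._deque = collections.deque()
        self._cv = threading.Condition()  # stands in for `deque_mu_` + `Await`
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def enqueue(self, element):
        with self._cv:
            self._deque.append(element)
            self._cv.notify()

    def stop(self):
        # Push the end-of-sequence signal so the thread closes its writer.
        with self._cv:
            self._deque.append(self._SENTINEL)
            self._cv.notify()
        self._thread.join()

    def _run(self):
        while True:
            with self._cv:
                while not self._deque:
                    self._cv.wait()
                element = self._deque.popleft()
            if element is self._SENTINEL:
                break
            self._write_fn(element)
```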
class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
: public DatasetIterator<Dataset> {
public:
static constexpr const char* const kIteratorName = "Passthrough";
explicit Passthrough(const Params& params);
Status Initialize(IteratorContext* ctx) override;
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override;
protected:
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override;
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override;
private:
std::unique_ptr<IteratorBase> input_impl_;
};
SnapshotDatasetV2Op::Dataset::Dataset(
OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
const std::string& path, const std::string& compression,
std::unique_ptr<CapturedFunction> reader_func,
std::unique_ptr<CapturedFunction> shard_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
hash_(hash),
path_(path),
compression_(compression),
reader_func_(std::move(reader_func)),
shard_func_(std::move(shard_func)) {
input_->Ref();
}
SnapshotDatasetV2Op::Dataset::~Dataset() { input_->Unref(); }
std::unique_ptr<IteratorBase>
SnapshotDatasetV2Op::Dataset::MakeIteratorInternal(const string& prefix) const {
return absl::make_unique<Iterator>(
Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
}
const DataTypeVector& SnapshotDatasetV2Op::Dataset::output_dtypes() const {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>&
SnapshotDatasetV2Op::Dataset::output_shapes() const {
return input_->output_shapes();
}
string SnapshotDatasetV2Op::Dataset::DebugString() const {
return name_utils::DatasetDebugString(kDatasetType);
}
int64 SnapshotDatasetV2Op::Dataset::Cardinality() const {
return input_->Cardinality();
}
Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const {
return input_->CheckExternalState();
}
Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* path = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(path_, &path));
std::vector<Node*> reader_func_other_args;
DataTypeVector reader_func_other_args_types;
TF_RETURN_IF_ERROR(reader_func_->AddToGraph(ctx, b, &reader_func_other_args,
&reader_func_other_args_types));
std::vector<Node*> shard_func_other_args;
DataTypeVector shard_func_other_args_types;
TF_RETURN_IF_ERROR(shard_func_->AddToGraph(ctx, b, &shard_func_other_args,
&shard_func_other_args_types));
AttrValue compression_attr;
b->BuildAttrValue(compression_, &compression_attr);
AttrValue reader_func_attr;
b->BuildAttrValue(reader_func_->func(), &reader_func_attr);
AttrValue shard_func_attr;
b->BuildAttrValue(shard_func_->func(), &shard_func_attr);
AttrValue reader_func_arguments_types_attr;
b->BuildAttrValue(reader_func_other_args_types,
&reader_func_arguments_types_attr);
AttrValue shard_func_arguments_types_attr;
b->BuildAttrValue(shard_func_other_args_types,
&shard_func_arguments_types_attr);
return b->AddDataset(
this,
/*inputs=*/
{std::make_pair(0, input_graph_node), std::make_pair(1, path)},
/*list_inputs=*/
{std::make_pair(2, reader_func_other_args),
std::make_pair(3, shard_func_other_args)},
/*attrs=*/
{{kCompression, compression_attr},
{kReaderFunc, reader_func_attr},
{kShardFunc, shard_func_attr},
{kReaderFuncTarguments, reader_func_arguments_types_attr},
{kShardFuncTarguments, shard_func_arguments_types_attr}},
output);
}
SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
index_(0),
hash_dir_(HashDirectory(dataset()->path_, dataset()->hash_)) {}
Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
IteratorContext* ctx) {
return ctx->env()->RecursivelyCreateDir(hash_dir_);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
SerializationContext* ctx, IteratorStateWriter* writer) {
mutex_lock l(mu_);
if (iterator_ != nullptr) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIteratorMode),
static_cast<int64>(mode_)));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kGraphHashDirectory), hash_dir_));
}
return Status::OK();
}
Status SnapshotDatasetV2Op::Dataset::Iterator::RestoreInternal(
IteratorContext* ctx, IteratorStateReader* reader) {
mutex_lock l(mu_);
if (reader->Contains(full_name(kIteratorMode))) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader));
return RestoreInput(ctx, reader, iterator_);
}
return Status::OK();
}
Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
mutex_lock l(mu_);
if (iterator_ == nullptr) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
}
// TODO(b/154341936): Explicitly stopping and starting this iterator
// should not be necessary, but the additional
// `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
// `iterator_` when it was created prevents the model from identifying this
// iterator as the output of `iterator_`.
RecordStop(ctx);
Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
index_++;
RecordStart(ctx);
return s;
}
Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
IteratorContext* ctx, IteratorStateReader* reader)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (reader != nullptr) {
// Check whether the computed hash directory matches the one saved in the checkpoint.
tstring hash_dir;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kGraphHashDirectory), &hash_dir));
if (hash_dir != hash_dir_) {
return errors::DataLoss(
"Dataset has changed while restoring from the checkpoint. Old hash "
"directory: ",
hash_dir, "; new hash directory: ", hash_dir_);
}
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(
snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
if (!file_exists) {
return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
" does not exist any more.");
}
int64 iterator_mode;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kIteratorMode), &iterator_mode));
mode_ = snapshot_util::Mode(iterator_mode);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
} else {
experimental::SnapshotMetadataRecord metadata;
bool file_exists;
TF_RETURN_IF_ERROR(
snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
// `pending_snapshot_expiry_seconds` is a legacy option under which we would
// not write a new snapshot if we believed another one was still being
// written. We decided this is not necessary for SnapshotV2: we always write
// a new snapshot regardless of whether someone else is currently writing
// one. Setting this to 0 ensures that all previous snapshots are ignored
// and we proceed to writing.
TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
/*mode_string=*/"", file_exists, &metadata,
/*pending_snapshot_expiry_seconds=*/0, &mode_));
}
switch (mode_) {
case snapshot_util::READER:
iterator_ = absl::make_unique<Reader>(
Reader::Params{dataset(),
absl::StrCat(prefix(), Reader::kIteratorName)},
index_);
break;
case snapshot_util::WRITER:
iterator_ = absl::make_unique<Writer>(Writer::Params{
dataset(), absl::StrCat(prefix(), Writer::kIteratorName)});
break;
case snapshot_util::PASSTHROUGH:
iterator_ = absl::make_unique<Passthrough>(Passthrough::Params{
dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
break;
}
return iterator_->Initialize(ctx);
}
SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
int64 start_index)
: DatasetIterator<Dataset>(params), start_index_(start_index) {}
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
IteratorContext* ctx) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));
tstring hash_dir = HashDirectory(dataset()->path_, dataset()->hash_);
bool metadata_file_exists;
experimental::SnapshotMetadataRecord metadata;
TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir, &metadata,
&metadata_file_exists));
auto run_dir = io::JoinPath(hash_dir, metadata.run_id());
std::vector<std::string> snapshot_shard_dirs;
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
io::JoinPath(run_dir,
absl::StrFormat("%s%s", "*", kShardDirectorySuffix)),
&snapshot_shard_dirs));
std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());
DatasetBase* dataset_of_snapshot_files;
TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
ctx->env(), snapshot_shard_dirs, dataset()->compression_,
kSnapshotFileFormatVersion, dataset()->output_dtypes(),
dataset()->output_shapes(), start_index_, &dataset_of_snapshot_files));
Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
&input_dataset_tensor));
std::vector<Tensor> reader_input;
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));
TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
ctx, std::move(reader_input), &reader_output));
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"`reader_func` must return exactly one output: the dataset to read from.");
}
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));
// We need to take a reference here as we will use the input_ and
// its iterator.
input_->Ref();
return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::GetNextInternal(
IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
mutex_lock l(mu_);
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
// We do not need to checkpoint the reader as we are rebuilding the reader
// datasets from information that is already saved by the main iterator.
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::SaveInternal(
SerializationContext* ctx, IteratorStateWriter* writer) {
return Status::OK();
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::RestoreInternal(
IteratorContext* ctx, IteratorStateReader* reader) {
return Status::OK();
}
SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)
: DatasetIterator<Dataset>(params),
writers_closed_(false),
run_id_(0),
current_checkpoint_id_(0) {}
SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
mutex_lock l(mu_);
StopWriterThreads(true);
}
void SnapshotDatasetV2Op::Dataset::Iterator::Writer::StopWriterThreads(
bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!writers_closed_) {
// Push the end of sequence signal to each of the threads to close files.
for (auto& writer_thread : writer_threads_) {
writer_thread.second->StopThread();
}
writer_threads_.clear();
writers_closed_ = mark_closed;
}
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
Env* env, bool finalized) {
DCHECK(!run_dir_.empty());
experimental::SnapshotMetadataRecord metadata;
metadata.set_creation_timestamp(EnvTime::NowMicros());
metadata.set_graph_hash(absl::StrFormat("%d", dataset()->hash_));
metadata.set_run_id(absl::StrFormat("%d", run_id_));
metadata.set_version(kSnapshotFileFormatVersion);
for (const auto& output_dtype : dataset()->output_dtypes()) {
metadata.add_dtype(output_dtype);
}
metadata.set_finalized(finalized);
tstring hash_directory = HashDirectory(dataset()->path_, dataset()->hash_);
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(hash_directory));
return snapshot_util::WriteMetadataFile(hash_directory, &metadata);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
IteratorContext* ctx) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_));
return dataset()->input_->MakeIterator(
ctx, this, strings::StrCat(prefix(), "::WriterIterator"), &input_impl_);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
std::vector<Tensor>* tensors, int64* shard_index) {
std::vector<Tensor> output_tensors;
// Run the shard function
TF_RETURN_IF_ERROR(
instantiated_shard_func_->RunInstantiated(*tensors, &output_tensors));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {
return errors::InvalidArgument("`shard_func` must return a scalar int64.");
}
// The scalar output of `shard_func` is the shard index for this element.
*shard_index = output_tensors[0].flat<int64>()(0);
return Status::OK();
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
*end_of_sequence = false;
WriterThread* current_writer_thread;
{
std::vector<Tensor> output_tensors;
mutex_lock l(mu_);
// We initialize late here because restoring from checkpoint comes after the
// Initialize call. We cannot initialize within Initialize() because we
// cannot determine whether we should overwrite an existing metadata file
// before `RestoreInternal` has had a chance to run.
if (run_dir_.empty()) {
run_id_ = random::New64();
// Creates the run directory.
run_dir_ = RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_),
run_id_);
TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
}
// Writers have either encountered an error or are closed.
if (!writer_status_.ok() || writers_closed_) {
*end_of_sequence = true;
return writer_status_;
}
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
// Finalize metadata file when we are at the end of the iterator.
if (*end_of_sequence) {
StopWriterThreads(/*mark_closed=*/true);
TF_RETURN_IF_ERROR(writer_status_);
return WriteMetadataFile(ctx->env(), /*finalized=*/true);
}
int64 shard_index = 0;
TF_RETURN_IF_ERROR(GetShardIndex(out_tensors, &shard_index));
// If the index does not exist, we will start a new thread.
if (writer_threads_.count(shard_index) == 0) {
const tstring snapshot_shard_directory =
SnapshotShardDirectory(run_dir_, shard_index);
auto thread_data = std::make_unique<WriterThread>(
this, ctx->env(), shard_index, snapshot_shard_directory,
current_checkpoint_id_);
writer_threads_.insert({shard_index, std::move(thread_data)});
}
current_writer_thread = writer_threads_[shard_index].get();
}
current_writer_thread->EnqueueTensors(*out_tensors);
return Status::OK();
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
SerializationContext* ctx, IteratorStateWriter* writer) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kRunId), static_cast<int64>(run_id_)));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kCurrentCheckpointId),
static_cast<int64>(current_checkpoint_id_)));
StopWriterThreads(/*mark_closed=*/false);
writer_threads_.clear();
current_checkpoint_id_++;
return SaveInput(ctx, writer, input_impl_);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
IteratorContext* ctx, IteratorStateReader* reader) {
mutex_lock l(mu_);
int64 run_id_signed;
int64 current_checkpoint_id;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_signed));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointId),
&current_checkpoint_id));
run_id_ = static_cast<uint64>(run_id_signed);
run_dir_ =
RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_), run_id_);
current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);
return RestoreInput(ctx, reader, input_impl_);
}
SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Passthrough(
const Params& params)
: DatasetIterator<Dataset>(params) {}
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Initialize(
IteratorContext* ctx) {
return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::GetNextInternal(
IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::SaveInternal(
SerializationContext* ctx, IteratorStateWriter* writer) {
return SaveInput(ctx, writer, input_impl_);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::RestoreInternal(
IteratorContext* ctx, IteratorStateReader* reader) {
return RestoreInput(ctx, reader, input_impl_);
}
SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
FunctionMetadata::Params reader_params;
FunctionMetadata::Params shard_params;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
&reader_func_metadata_));
OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
&shard_func_metadata_));
}
void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) {
tstring path;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));
// Computes a hash of the input pipeline's graph, which identifies the snapshot.
uint64 graph_hash;
GraphDef graph_def;
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.external_state_policy =
SerializationContext::ExternalStatePolicy::kIgnore;
OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));
OP_REQUIRES_OK(ctx, HashGraph(graph_def, &graph_hash));
std::unique_ptr<CapturedFunction> reader_func;
OP_REQUIRES_OK(ctx,
CapturedFunction::Create(ctx, reader_func_metadata_,
kReaderFuncOtherArgs, &reader_func));
std::unique_ptr<CapturedFunction> shard_func;
OP_REQUIRES_OK(ctx,
CapturedFunction::Create(ctx, shard_func_metadata_,
kShardFuncOtherArgs, &shard_func));
*output = new SnapshotDatasetV2Op::Dataset(
ctx, input, graph_hash, path, compression_, std::move(reader_func),
std::move(shard_func));
}
namespace {
REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetV2").Device(DEVICE_CPU),
SnapshotDatasetV2Op);
} // namespace
// ==== Legacy Snapshot Implementation (Deprecated) ====
namespace {
// Defaults to 10 GiB per shard.

View File

@@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace data {
namespace experimental {
const int64 kSnapshotFileFormatVersion = 1;
class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "Snapshot";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kCompression = "compression";
static constexpr const char* const kReaderFunc = "reader_func";
static constexpr const char* const kShardFunc = "shard_func";
static constexpr const char* const kReaderFuncOtherArgs =
"reader_func_other_args";
static constexpr const char* const kShardFuncOtherArgs =
"shard_func_other_args";
static constexpr const char* const kReaderFuncTarguments =
"Treader_func_args";
static constexpr const char* const kShardFuncTarguments = "Tshard_func_args";
explicit SnapshotDatasetV2Op(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override;
private:
class Dataset;
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
std::string compression_;
std::shared_ptr<FunctionMetadata> reader_func_metadata_;
std::shared_ptr<FunctionMetadata> shard_func_metadata_;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_

View File

@@ -899,6 +899,26 @@ REGISTER_OP("SnapshotDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("SnapshotDatasetV2")
.Input("input_dataset: variant")
.Input("path: string")
.Input("reader_func_other_args: Treader_func_args")
.Input("shard_func_other_args: Tshard_func_args")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("compression: string = ''")
.Attr("reader_func: func")
.Attr("shard_func: func")
.Attr("Treader_func_args: list(type) >= 0")
.Attr("Tshard_func_args: list(type) >= 0")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `path` should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("SqlDataset")
.Input("driver_name: string")
.Input("data_source_name: string")

View File

@@ -75,6 +75,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@sample_from_datasets
@@scan
@@shuffle_and_repeat
@@snapshot
@@take_while
@@to_variant
@@unbatch
@@ -128,6 +129,7 @@ from tensorflow.python.data.experimental.ops.readers import SqlDataset
from tensorflow.python.data.experimental.ops.resampling import rejection_resample
from tensorflow.python.data.experimental.ops.scan_ops import scan
from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.python.data.experimental.ops.snapshot import snapshot
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import bytes_produced_stats
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats

View File

@@ -753,9 +753,12 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:string_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/experimental/ops:snapshot",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:nest",
"@absl_py//absl/testing:parameterized",
],
)

View File

@@ -199,7 +199,9 @@ tf_py_test(
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)

View File

@@ -34,6 +34,97 @@ class SnapshotDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase,
parameterized.TestCase):
def _build_snapshot_dataset(self, repeat=False):
def ds_fn():
self._snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
if not os.path.exists(self._snapshot_dir):
os.mkdir(self._snapshot_dir)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
if repeat:
dataset = dataset.repeat(2)
return dataset
return ds_fn
@combinations.generate(test_base.default_test_combinations())
def testCheckpointBeforeEpochEndNoRepeat(self):
ds_fn = self._build_snapshot_dataset(repeat=False)
outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
self.assertSequenceEqual(outputs, range(50))
outputs.extend(
self.gen_outputs(ds_fn, [], 50, ckpt_saved=True, verify_exhausted=True))
self.assertSequenceEqual(outputs, range(100))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointBeforeOneEpochWithReading(self):
ds_fn = self._build_snapshot_dataset(repeat=True)
# Generate 50 entries from iterator and save checkpoint.
outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(50)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
t = self.gen_outputs(ds_fn, [], 150, ckpt_saved=True, verify_exhausted=True)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(50)) + list(range(50, 100)) + list(range(100)))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointBeforeOneEpochThenRunAFewSteps(self):
ds_fn = self._build_snapshot_dataset(repeat=False)
outputs = self.gen_outputs(
ds_fn, [10], 20, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, range(20))
outputs = outputs[:10]
outputs.extend(
self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True))
self.assertSequenceEqual(outputs, range(100))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointAfterOneEpoch(self):
ds_fn = self._build_snapshot_dataset(repeat=True)
# Generate 110 entries from iterator and save checkpoint.
outputs = self.gen_outputs(ds_fn, [], 110, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(100)) + list(range(10)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(100)) + list(range(10)) + list(range(10, 100)))
@combinations.generate(test_base.default_test_combinations())
def testCheckpointAfterOneEpochRunFewSteps(self):
ds_fn = self._build_snapshot_dataset(repeat=True)
# Generate 120 entries from iterator and save checkpoint at 110.
outputs = self.gen_outputs(
ds_fn, [110], 120, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, list(range(100)) + list(range(20)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
outputs = outputs[:110]
t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
outputs.extend(t)
self.assertSequenceEqual(
outputs,
list(range(100)) + list(range(10)) + list(range(10, 100)))
class LegacySnapshotDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase,
parameterized.TestCase):
def _build_snapshot_dataset(self,
num_threads=1,
repeat=False,

View File

@@ -17,11 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import shutil
import time
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import snapshot
@@ -40,6 +42,285 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
def setUp(self):
super(SnapshotDatasetTest, self).setUp()
tmpdir = self.get_temp_dir()
tmpdir = os.path.join(tmpdir, "snapshot")
os.mkdir(tmpdir)
self._snapshot_dir = tmpdir
def tearDown(self):
super(SnapshotDatasetTest, self).tearDown()
shutil.rmtree(self._snapshot_dir)
def createTFRecords(self, num_files=10, num_records=100):
self._num_files = num_files
self._num_records = num_records
self._test_filenames = self._createFiles()
def removeTFRecords(self):
for filename in self._test_filenames:
os.remove(filename)
self._test_filenames = []
self._num_files = None
self._num_records = None
def assertDatasetProducesSet(self, dataset, expected):
actual = []
next_fn = self.getNext(dataset)
for _ in range(len(expected)):
elem = self.evaluate(next_fn())
actual.append(elem)
self.assertCountEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_fn())
def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
num_runs_per_fingerprint,
num_snapshot_shards_per_run):
dirlist_raw = os.listdir(directory)
dirlist = []
# Ignore the graphdef pbtxts we write for debugging purposes.
for i in range(len(dirlist_raw)):
if not dirlist_raw[i].endswith("-graph.pbtxt"):
dirlist.append(dirlist_raw[i])
self.assertLen(dirlist, num_fingerprints)
for i in range(num_fingerprints):
fingerprint_dir = os.path.join(directory, dirlist[i])
fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1)
self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint],
"snapshot.metadata")
for j in range(num_runs_per_fingerprint):
run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
run_dirlist = sorted(os.listdir(run_dir))
self.assertLen(run_dirlist, num_snapshot_shards_per_run)
file_counter = 0
for filename in run_dirlist:
self.assertEqual(filename, "%08d.shard" % file_counter)
file_counter += 1
@combinations.generate(test_base.default_test_combinations())
def testCreateSnapshotDataset(self):
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
dataset.apply(snapshot.snapshot(self._snapshot_dir))
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetDefault(self):
self.createTFRecords()
filenames = self._test_filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetCustomShardFn(self):
self.createTFRecords()
filenames = self._test_filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: np.int64(0)))
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=1)
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0))
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testReadSnapshotDatasetCustomReaderFn(self):
self.createTFRecords()
filenames = self._test_filenames
expected = [
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
for f in range(0, 10)
for r in range(0, 100)
]
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(
snapshot.snapshot(
self._snapshot_dir,
reader_func=(
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=4,
num_parallel_calls=4))))
self.assertDatasetProduces(dataset, expected)
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(
snapshot.snapshot(
self._snapshot_dir,
reader_func=(
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=4,
num_parallel_calls=4))))
self.assertDatasetProducesSet(dataset2, expected)
@combinations.generate(test_base.default_test_combinations())
def testSnapshotDatasetInvalidShardFn(self):
dataset = dataset_ops.Dataset.range(1000)
with self.assertRaises(TypeError):
dataset = dataset.apply(
snapshot.snapshot(
self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
next_fn = self.getNext(dataset)
self.evaluate(next_fn())
@combinations.generate(test_base.default_test_combinations())
def testSnapshotDatasetInvalidReaderFn(self):
dataset = dataset_ops.Dataset.range(1000)
with self.assertRaises(TypeError):
dataset = dataset.apply(
snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
next_fn = self.getNext(dataset)
self.evaluate(next_fn())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSimple(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetMultipleFingerprints(self):
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset1, list(range(1000)))
dataset2 = dataset_ops.Dataset.range(2000)
dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset2, list(range(2000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=2,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset1, list(range(1000)))
dataset2 = dataset_ops.Dataset.range(1000)
dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
self.assertDatasetProduces(dataset2, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
dataset1 = dataset_ops.Dataset.range(1000)
dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
next1 = self.getNext(dataset1)
for i in range(500):
self.assertEqual(i, self.evaluate(next1()))
dataset2 = dataset_ops.Dataset.range(1000)
dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
next2 = self.getNext(dataset2)
for i in range(500):
self.assertEqual(i, self.evaluate(next2()))
for i in range(500, 1000):
self.assertEqual(i, self.evaluate(next1()))
self.assertEqual(i, self.evaluate(next2()))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=2,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotCustomShardFunction(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.enumerate()
dataset = dataset.apply(
snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2))
dataset = dataset.map(lambda _, elem: elem)
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=2)
@combinations.generate(test_base.default_test_combinations())
def testWriteSnapshotDatasetWithTuples(self):
dataset1 = dataset_ops.Dataset.range(0, 1000)
dataset2 = dataset_ops.Dataset.range(1000, 2000)
dataset3 = dataset_ops.Dataset.range(2000, 3000)
dataset4 = dataset_ops.Dataset.range(3000, 4000)
dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4))
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
next1 = self.getNext(dataset)
for i in range(0, 1000):
self.assertEqual((i, i + 1000, i + 2000, i + 3000),
self.evaluate(next1()))
self.assertSnapshotDirectoryContains(
self._snapshot_dir,
num_fingerprints=1,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
class LegacySnapshotDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase,
parameterized.TestCase):
def setUp(self):
super(LegacySnapshotDatasetTest, self).setUp()
self.removeTFRecords()
tmpdir = self.get_temp_dir()
tmpdir = os.path.join(tmpdir, "snapshot")
@@ -47,7 +328,7 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
self.snapshot_dir = tmpdir
def tearDown(self):
super(SnapshotDatasetTest, self).tearDown()
super(LegacySnapshotDatasetTest, self).tearDown()
shutil.rmtree(self.snapshot_dir)
def removeTFRecords(self):
@@ -63,8 +344,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
def makeSnapshotDirectory(self):
return self.snapshot_dir
def assertSnapshotDirectoryContains(
self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files):
def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
num_runs_per_fp, num_snapshot_files):
dirlist_raw = os.listdir(directory)
dirlist = []
@@ -465,8 +746,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
]),
combinations.combine(threads=2, size=[1, 2]) +
combinations.combine(threads=8, size=[1, 4, 8]))))
def testReadSnapshotBackAfterMultiThreadedWrite(
self, compression, threads, size):
def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
size):
self.setUpTFRecord()
filenames = self.test_filenames

View File

@@ -17,12 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
@@ -99,6 +103,8 @@ class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)
@deprecation.deprecated(
None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
compression=None,
reader_path_prefix=None,
@@ -186,3 +192,165 @@ def legacy_snapshot(path,
snapshot_name=snapshot_name)
return _apply_fn
class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A dataset that allows saving and re-use of already processed data."""
def __init__(self,
input_dataset,
path,
shard_func,
compression=None,
reader_func=None,
pending_snapshot_expiry_seconds=None,
use_legacy_function=False):
if reader_func is None:
reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=multiprocessing.cpu_count(),
num_parallel_calls=dataset_ops.AUTOTUNE)
self._input_dataset = input_dataset
self._path = path
self._compression = compression
self._reader_func = dataset_ops.StructuredFunctionWrapper(
reader_func,
self._transformation_name() + ".reader_func",
# Dataset of datasets of input elements
input_structure=dataset_ops.DatasetSpec(
dataset_ops.DatasetSpec(input_dataset.element_spec)),
use_legacy_function=use_legacy_function)
self._shard_func = dataset_ops.StructuredFunctionWrapper(
shard_func,
self._transformation_name() + ".shard_func",
dataset=input_dataset,
use_legacy_function=use_legacy_function)
if ((not self._shard_func.output_structure.is_compatible_with(
tensor_spec.TensorSpec([], dtypes.int32))) and
(not self._shard_func.output_structure.is_compatible_with(
tensor_spec.TensorSpec([], dtypes.int64)))):
raise TypeError(
"shard_func must return a scalar (0-dimensional) int32 or int64 tensor.")
variant_tensor = ged_ops.snapshot_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
path,
self._reader_func.function.captured_inputs,
self._shard_func.function.captured_inputs,
compression=compression,
reader_func=self._reader_func.function,
shard_func=self._shard_func.function,
**self._flat_structure)
super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._reader_func, self._shard_func]
def _transformation_name(self):
return "Dataset.snapshot"
@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
"""API to persist the output of the input dataset.
The snapshot API allows users to transparently persist the output of their
preprocessing pipeline to disk, and materialize the pre-processed data on a
different training run.
This API enables repeated preprocessing steps to be consolidated, and allows
re-use of already processed data, trading off disk storage and network
bandwidth for freeing up more valuable CPU resources and accelerator compute
time.
https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
has detailed design documentation of this feature.
Users can specify various options to control the behavior of snapshot,
including how snapshots are read and written, by passing in user-defined
functions to the `reader_func` and `shard_func` parameters.
`shard_func` is a user-specified function that maps input elements to snapshot
shards.
Users may want to specify this function to control how snapshot files should
be written to disk. Below is an example of how a potential `shard_func` could
be written.
```python
dataset = ...
dataset = dataset.enumerate()
dataset = dataset.apply(tf.data.experimental.snapshot(
shard_func=lambda x, y: x % NUM_SHARDS, ...))
dataset = dataset.map(lambda x, y: y)
```
`reader_func` is a user-specified function that accepts a single argument:
a Dataset of Datasets, each representing a "split" of elements of the
original dataset. The cardinality of the input dataset matches the
number of shards specified by the `shard_func` (see above). The function
should return a Dataset of elements of the original dataset.
Users may want to specify this function to control how snapshot files should
be read from disk, including the amount of shuffling and parallelism.
Here is an example of a standard reader function a user can define. This
function enables both dataset shuffling and parallel reading of datasets:
```python
def user_reader_func(datasets):
# shuffle the datasets splits
datasets = datasets.shuffle(NUM_CORES)
# read datasets in parallel and interleave their elements
return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
dataset = dataset.apply(tf.data.experimental.snapshot(
reader_func=user_reader_func))
```
By default, snapshot parallelizes reads by the number of cores available on
the system, but does not attempt to shuffle the data.
Args:
path: Required. A directory to use for storing / loading the snapshot to /
from.
compression: Optional. The type of compression to apply to the snapshot
written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
Defaults to AUTO, which attempts to pick an appropriate compression
algorithm for the dataset.
reader_func: Optional. A function to control how to read data from snapshot
shards.
shard_func: Optional. A function to control how to shard data when writing a
snapshot.
Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
"""Actual dataset transformation."""
if shard_func is None:
dataset = dataset.enumerate()
dataset = _SnapshotDataset(
input_dataset=dataset,
path=path,
compression=compression,
reader_func=reader_func,
# This will not do the right thing when the graph is built on a
# different machine than the executor (e.g. Cloud TPUs).
shard_func=lambda index, _: index % multiprocessing.cpu_count())
return dataset.map(lambda _, elem: elem)
else:
return _SnapshotDataset(
input_dataset=dataset,
path=path,
compression=compression,
reader_func=reader_func,
shard_func=shard_func)
return _apply_fn
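Putting the pieces together, a hedged end-to-end sketch combining a custom `shard_func` with a shuffling `reader_func` (the path, shard count, and pipeline here are illustrative):

```python
import multiprocessing

import tensorflow as tf

NUM_SHARDS = 8  # illustrative shard count

def reader_func(datasets):
    # Shuffle the per-shard datasets, then interleave their elements in parallel.
    datasets = datasets.shuffle(multiprocessing.cpu_count())
    return datasets.interleave(
        lambda x: x, num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset = tf.data.Dataset.range(10000)
dataset = dataset.enumerate()
dataset = dataset.apply(
    tf.data.experimental.snapshot(
        "/tmp/snapshot_example",  # hypothetical snapshot directory
        compression="GZIP",
        shard_func=lambda index, _: index % NUM_SHARDS,
        reader_func=reader_func))
dataset = dataset.map(lambda _, elem: elem)  # drop the enumerate index
```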

View File

@@ -220,6 +220,10 @@ tf_module {
name: "shuffle_and_repeat"
argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "snapshot"
argspec: "args=[\'path\', \'compression\', \'reader_func\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\'], "
}
member_method {
name: "take_while"
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"

View File

@@ -3992,6 +3992,10 @@ tf_module {
name: "SnapshotDataset"
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
}
member_method {
name: "SnapshotDatasetV2"
argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "SobolSample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "

View File

@@ -188,6 +188,10 @@ tf_module {
name: "shuffle_and_repeat"
argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "snapshot"
argspec: "args=[\'path\', \'compression\', \'reader_func\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\'], "
}
member_method {
name: "take_while"
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"

View File

@@ -3992,6 +3992,10 @@ tf_module {
name: "SnapshotDataset"
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
}
member_method {
name: "SnapshotDatasetV2"
argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "SobolSample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "