Add implementation for snapshot dataset v2.
PiperOrigin-RevId: 314258033
Change-Id: I6151fdc646a297090de6eeeb3254a556ae9d13bc
parent 70387ab55b, commit a3f393bc95
@@ -0,0 +1,41 @@
op {
  graph_op_name: "SnapshotDatasetV2"
  visibility: HIDDEN
  in_arg {
    name: "input_dataset"
    description: <<END
A variant tensor representing the input dataset.
END
  }
  in_arg {
    name: "path"
    description: <<END
The path we should write snapshots to / read snapshots from.
END
  }
  attr {
    name: "compression"
    description: <<END
The type of compression to be applied to the saved snapshot files.
END
  }
  attr {
    name: "reader_func"
    description: <<END
Optional. A function to control how to read data from snapshot shards.
END
  }
  attr {
    name: "shard_func"
    description: <<END
Optional. A function to control how to shard data when writing a snapshot.
END
  }
  summary: "Creates a dataset that will write to / read from a snapshot."
  description: <<END
This dataset attempts to determine whether a valid snapshot exists at the
`snapshot_path`, and reads from the snapshot in lieu of using `input_dataset`.
If not, it will run the preprocessing pipeline as usual, and write out a
snapshot of the data processed for future use.
END
}
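
For reference, a minimal usage sketch of the transformation this op backs, based on the Python tests added in this change; the snapshot path is a placeholder and the exact `snapshot.snapshot` signature may differ between releases:

  import tensorflow as tf
  from tensorflow.python.data.experimental.ops import snapshot

  # First run writes a snapshot under the given path; later runs with the same
  # input pipeline read the snapshot back instead of re-running preprocessing.
  dataset = tf.data.Dataset.range(100)
  dataset = dataset.apply(
      snapshot.snapshot(
          "/tmp/snapshot_dir",                   # placeholder directory
          shard_func=lambda x: x % 2,            # optional: route elements to shards (scalar int64)
          reader_func=lambda ds: ds.interleave(  # optional: control how shards are read back
              lambda x: x, cycle_length=4, num_parallel_calls=4)))
  for elem in dataset:
    pass  # consume the dataset as usual
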
@@ -544,6 +544,7 @@ cc_library(
tf_kernel_library(
    name = "snapshot_dataset_op",
    srcs = ["snapshot_dataset_op.cc"],
    hdrs = ["snapshot_dataset_op.h"],
    deps = [
        ":snapshot_util",
        "//tensorflow/core:core_cpu_internal",
@@ -552,10 +553,16 @@ tf_kernel_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:op_requires",
        "//tensorflow/core/grappler:graph_view",
        "//tensorflow/core/kernels/data:captured_function",
        "//tensorflow/core/kernels/data:dataset_utils",
        "//tensorflow/core/kernels/data:name_utils",
        "//tensorflow/core/platform:platform_port",
        "//tensorflow/core/profiler/lib:traceme",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
)
@@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h"

#include <random>

#include "absl/time/clock.h"
@@ -60,6 +62,875 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace experimental {

// ==== Snapshot Implementation ====

namespace {

/* The current snapshot on-disk layout is as follows:
 * /user/specified/path/
 *   - graphhash1/
 *     - snapshot.metadata  // metadata file
 *     - run1/
 *       - 00000000.shard/  // shard index
 *         // new checkpoint files are created on all threads at once, either
 *         // when a file gets too big, or when a TF checkpoint happens.
 *         - 00000000.snapshot  // checkpoint file 0
 *         - 00000001.snapshot  // checkpoint file 1
 *         - ...
 *       - 00000001.shard/
 *         - 00000000.snapshot
 *         - 00000001.snapshot
 *         - ...
 *       - 00000002.shard/
 *         - 00000000.snapshot
 *         - 00000001.snapshot
 *         - ...
 *       ...
 *     - run2/
 *       ...
 *   - graphhash2/
 *     ...
 *   - graphhash3/
 *     ...
 */

constexpr const char* const kShardDirectorySuffix = ".shard";

inline tstring HashDirectory(const tstring& path, const uint64 hash) {
  return io::JoinPath(path, absl::StrFormat("%d", hash));
}

inline tstring RunDirectory(const tstring& hash_directory,
                            const uint64 run_id) {
  return io::JoinPath(hash_directory, absl::StrFormat("%d", run_id));
}

inline tstring SnapshotShardDirectory(const tstring& run_directory,
                                      const int64 snapshot_index) {
  return io::JoinPath(run_directory, absl::StrFormat("%08d%s", snapshot_index,
                                                     kShardDirectorySuffix));
}

}  // namespace
class SnapshotDatasetV2Op::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
          const std::string& path, const std::string& compression,
          std::unique_ptr<CapturedFunction> reader_func,
          std::unique_ptr<CapturedFunction> shard_func);

  ~Dataset() override;

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  const DataTypeVector& output_dtypes() const override;

  const std::vector<PartialTensorShape>& output_shapes() const override;

  string DebugString() const override;

  int64 Cardinality() const override;

  Status CheckExternalState() const override;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  const DatasetBase* input_;
  const uint64 hash_;
  const tstring path_;
  const std::string compression_;

  std::unique_ptr<CapturedFunction> reader_func_;
  std::unique_ptr<CapturedFunction> shard_func_;

  class Iterator;
};
class SnapshotDatasetV2Op::Dataset::Iterator : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorMode = "iterator_mode";
  static constexpr const char* const kIndex = "index";
  static constexpr const char* const kGraphHashDirectory =
      "graph_hash_directory";

  explicit Iterator(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  Status InitializeIterator(IteratorContext* ctx, IteratorStateReader* reader);

  int64 index_ TF_GUARDED_BY(mu_);
  std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  snapshot_util::Mode mode_ TF_GUARDED_BY(mu_);
  const std::string hash_dir_;

  mutex mu_;

  class Reader;
  class Writer;
  class Passthrough;
};
class SnapshotDatasetV2Op::Dataset::Iterator::Reader
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Reader";

  explicit Reader(const Params& params, int64 start_index);

  ~Reader() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  const int64 start_index_;

  mutex mu_;

  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  DatasetBase* input_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
      TF_GUARDED_BY(mu_);
};
class SnapshotDatasetV2Op::Dataset::Iterator::Writer
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Writer";
  static constexpr const char* const kRunId = "run_id";
  static constexpr const char* const kCurrentCheckpointId =
      "current_checkpoint_id";

  explicit Writer(const Params& params);

  ~Writer() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  struct BufferElement {
    std::vector<Tensor> value;
    bool end_of_sequence = false;
  };

  class WriterThread {
   public:
    explicit WriterThread(Writer* writer_iterator, Env* env, int64 file_index,
                          const tstring& shard_directory,
                          uint64 current_checkpoint_id) {
      thread_ = absl::WrapUnique(env->StartThread(
          ThreadOptions(), absl::StrCat("snapshot_writer_thread_", file_index),
          [this, writer_iterator, env, shard_directory, current_checkpoint_id] {
            RunWriterThread(writer_iterator, env, shard_directory,
                            current_checkpoint_id);
          }));
    }

    void EnqueueTensors(const std::vector<Tensor>& tensors)
        TF_LOCKS_EXCLUDED(deque_mu_) {
      // Copy the Tensor to the deque for writing.
      mutex_lock l(deque_mu_);
      BufferElement be;
      be.value = tensors;
      deque_.push_back(std::move(be));
    }

    void DequeueTensors(BufferElement* be) TF_LOCKS_EXCLUDED(deque_mu_) {
      mutex_lock l(deque_mu_);
      deque_mu_.Await(
          tensorflow::Condition(this, &WriterThread::DequeIsNotEmpty));
      *be = deque_.front();
      deque_.pop_front();
    }

    void StopThread() TF_LOCKS_EXCLUDED(deque_mu_) {
      mutex_lock l(deque_mu_);
      BufferElement be;
      be.end_of_sequence = true;
      deque_.push_back(std::move(be));
    }

    void RunWriterThread(Writer* writer_iterator, Env* env,
                         const tstring& shard_directory,
                         uint64 current_checkpoint_id) {
      Status s = WriterThreadFn(writer_iterator, env, shard_directory,
                                current_checkpoint_id);
      if (!s.ok()) {
        mutex_lock l(writer_iterator->mu_);
        writer_iterator->writer_status_ = s;
      }
    }

   private:
    bool DequeIsNotEmpty() TF_EXCLUSIVE_LOCKS_REQUIRED(deque_mu_) {
      return !deque_.empty();
    }

    Status WriterThreadFn(Writer* writer_iterator, Env* env,
                          const tstring& shard_directory,
                          uint64 current_checkpoint_id) {
      std::unique_ptr<snapshot_util::Writer> writer;
      TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

      TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
          env,
          snapshot_util::GetCurrentCheckpointFile(shard_directory,
                                                  current_checkpoint_id),
          writer_iterator->dataset()->compression_, kSnapshotFileFormatVersion,
          writer_iterator->dataset()->output_dtypes(), &writer));

      while (true) {
        BufferElement be;
        DequeueTensors(&be);

        if (be.end_of_sequence) {
          TF_RETURN_IF_ERROR(writer->Close());
          break;
        }

        TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
      }
      return Status::OK();
    }

    // If both the writer `mu_` and this `deque_mu_` need to be acquired, the
    // writer `mu_` must be acquired first.
    mutex deque_mu_;

    std::deque<BufferElement> deque_ TF_GUARDED_BY(deque_mu_);

    // This has to be last. During destruction, we need to make sure that
    // thread_ is destroyed first as the thread destructor blocks on thread
    // completion. If there are other member variables after this, they may get
    // destroyed first before the thread finishes, potentially causing the
    // thread to access invalid memory.
    std::unique_ptr<Thread> thread_;
  };

  Status GetShardIndex(std::vector<Tensor>* tensors, int64* shard_index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status WriteMetadataFile(Env* env, bool finalized)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void StopWriterThreads(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutex mu_;
  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  absl::flat_hash_map<int64, std::unique_ptr<WriterThread>> writer_threads_
      TF_GUARDED_BY(mu_);
  Status writer_status_ TF_GUARDED_BY(mu_);
  bool writers_closed_ TF_GUARDED_BY(mu_);

  uint64 run_id_ TF_GUARDED_BY(mu_);
  tstring run_dir_ TF_GUARDED_BY(mu_);

  // Stores the ID of the current checkpoint .snapshot file being read. See top
  // of this file for the directory layout.
  uint64 current_checkpoint_id_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_shard_func_
      TF_GUARDED_BY(mu_);
};
class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Passthrough";

  explicit Passthrough(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  std::unique_ptr<IteratorBase> input_impl_;
};
SnapshotDatasetV2Op::Dataset::Dataset(
    OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
    const std::string& path, const std::string& compression,
    std::unique_ptr<CapturedFunction> reader_func,
    std::unique_ptr<CapturedFunction> shard_func)
    : DatasetBase(DatasetContext(ctx)),
      input_(input),
      hash_(hash),
      path_(path),
      compression_(compression),
      reader_func_(std::move(reader_func)),
      shard_func_(std::move(shard_func)) {
  input_->Ref();
}

SnapshotDatasetV2Op::Dataset::~Dataset() { input_->Unref(); }

std::unique_ptr<IteratorBase>
SnapshotDatasetV2Op::Dataset::MakeIteratorInternal(const string& prefix) const {
  return absl::make_unique<Iterator>(
      Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
}

const DataTypeVector& SnapshotDatasetV2Op::Dataset::output_dtypes() const {
  return input_->output_dtypes();
}

const std::vector<PartialTensorShape>&
SnapshotDatasetV2Op::Dataset::output_shapes() const {
  return input_->output_shapes();
}

string SnapshotDatasetV2Op::Dataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}

int64 SnapshotDatasetV2Op::Dataset::Cardinality() const {
  return input_->Cardinality();
}

Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const {
  return input_->CheckExternalState();
}

Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
    SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const {
  Node* input_graph_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

  Node* path = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(path_, &path));

  std::vector<Node*> reader_func_other_args;
  DataTypeVector reader_func_other_args_types;
  TF_RETURN_IF_ERROR(reader_func_->AddToGraph(ctx, b, &reader_func_other_args,
                                              &reader_func_other_args_types));

  std::vector<Node*> shard_func_other_args;
  DataTypeVector shard_func_other_args_types;
  TF_RETURN_IF_ERROR(shard_func_->AddToGraph(ctx, b, &shard_func_other_args,
                                             &shard_func_other_args_types));

  AttrValue compression_attr;
  b->BuildAttrValue(compression_, &compression_attr);

  AttrValue reader_func_attr;
  b->BuildAttrValue(reader_func_->func(), &reader_func_attr);

  AttrValue shard_func_attr;
  b->BuildAttrValue(shard_func_->func(), &shard_func_attr);

  AttrValue reader_func_arguments_types_attr;
  b->BuildAttrValue(reader_func_other_args_types,
                    &reader_func_arguments_types_attr);

  AttrValue shard_func_arguments_types_attr;
  b->BuildAttrValue(shard_func_other_args_types,
                    &shard_func_arguments_types_attr);

  return b->AddDataset(
      this,
      /*inputs=*/
      {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
      /*list_inputs=*/
      {std::make_pair(2, reader_func_other_args),
       std::make_pair(3, shard_func_other_args)},
      /*attrs=*/
      {{kCompression, compression_attr},
       {kReaderFunc, reader_func_attr},
       {kShardFunc, shard_func_attr},
       {kReaderFuncTarguments, reader_func_arguments_types_attr},
       {kShardFuncTarguments, shard_func_arguments_types_attr}},
      output);
}
SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
    : DatasetIterator<Dataset>(params),
      index_(0),
      hash_dir_(HashDirectory(dataset()->path_, dataset()->hash_)) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
    IteratorContext* ctx) {
  return ctx->env()->RecursivelyCreateDir(hash_dir_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  if (iterator_ != nullptr) {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIteratorMode),
                                           static_cast<int64>(mode_)));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name(kGraphHashDirectory), hash_dir_));
  }

  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  mutex_lock l(mu_);

  if (reader->Contains(full_name(kIteratorMode))) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader));
    return RestoreInput(ctx, reader, iterator_);
  }

  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  if (iterator_ == nullptr) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
  }
  // TODO(b/154341936): Explicitly stopping and starting this iterator
  // should not be necessary, but the additional
  // `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
  // `iterator_` when it was created prevents the model from identifying this
  // iterator as the output of `iterator_`.
  RecordStop(ctx);
  Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
  index_++;
  RecordStart(ctx);
  return s;
}

Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
    IteratorContext* ctx, IteratorStateReader* reader)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (reader != nullptr) {
    // Check whether the computed hash directory is the same.
    tstring hash_dir;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kGraphHashDirectory), &hash_dir));
    if (hash_dir != hash_dir_) {
      return errors::DataLoss(
          "Dataset has changed while restoring from the checkpoint. Old hash "
          "directory: ",
          hash_dir, "; new hash directory: ", hash_dir_);
    }

    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(
        snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
    if (!file_exists) {
      return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                              " does not exist any more.");
    }

    int64 iterator_mode;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kIteratorMode), &iterator_mode));
    mode_ = snapshot_util::Mode(iterator_mode);

    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
  } else {
    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(
        snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));

    // `pending_snapshot_expiry_seconds` is a legacy option where we would not
    // write snapshots that we think were still on-going. We decided that this
    // would not be necessary as a feature for SnapshotV2, and we would always
    // write a new snapshot regardless of whether someone else is currently
    // writing one. Setting this to 0 ensures that all previous snapshots
    // will be ignored and we will proceed to writing.
    TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
        /*mode_string=*/"", file_exists, &metadata,
        /*pending_snapshot_expiry_seconds=*/0, &mode_));
  }

  switch (mode_) {
    case snapshot_util::READER:
      iterator_ = absl::make_unique<Reader>(
          Reader::Params{dataset(),
                         absl::StrCat(prefix(), Reader::kIteratorName)},
          index_);
      break;
    case snapshot_util::WRITER:
      iterator_ = absl::make_unique<Writer>(Writer::Params{
          dataset(), absl::StrCat(prefix(), Writer::kIteratorName)});
      break;
    case snapshot_util::PASSTHROUGH:
      iterator_ = absl::make_unique<Passthrough>(Passthrough::Params{
          dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
      break;
  }
  return iterator_->Initialize(ctx);
}
SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
                                                       int64 start_index)
    : DatasetIterator<Dataset>(params), start_index_(start_index) {}

SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);

  TF_RETURN_IF_ERROR(
      dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));

  tstring hash_dir = HashDirectory(dataset()->path_, dataset()->hash_);
  bool metadata_file_exists;
  experimental::SnapshotMetadataRecord metadata;
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir, &metadata,
                                                     &metadata_file_exists));

  auto run_dir = io::JoinPath(hash_dir, metadata.run_id());

  std::vector<std::string> snapshot_shard_dirs;
  TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
      io::JoinPath(run_dir,
                   absl::StrFormat("%s%s", "*", kShardDirectorySuffix)),
      &snapshot_shard_dirs));
  std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

  DatasetBase* dataset_of_snapshot_files;
  TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
      ctx->env(), snapshot_shard_dirs, dataset()->compression_,
      kSnapshotFileFormatVersion, dataset()->output_dtypes(),
      dataset()->output_shapes(), start_index_, &dataset_of_snapshot_files));

  Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
                                                 &input_dataset_tensor));

  std::vector<Tensor> reader_input;
  std::vector<Tensor> reader_output;
  reader_input.push_back(std::move(input_dataset_tensor));

  TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
      ctx, std::move(reader_input), &reader_output));
  if (reader_output.size() != 1) {
    return errors::InvalidArgument(
        "reader_func returns more than one argument.");
  }
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));

  // We need to take a reference here as we will use the input_ and
  // its iterator.
  input_->Ref();

  return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}

// We do not need to checkpoint the reader as we are rebuilding the reader
// datasets from information that is already saved by the main iterator.
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return Status::OK();
}
SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)
    : DatasetIterator<Dataset>(params),
      writers_closed_(false),
      run_id_(0),
      current_checkpoint_id_(0) {}

SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
  mutex_lock l(mu_);
  StopWriterThreads(true);
}

void SnapshotDatasetV2Op::Dataset::Iterator::Writer::StopWriterThreads(
    bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!writers_closed_) {
    // Push the end of sequence signal to each of the threads to close files.
    for (auto& writer_thread : writer_threads_) {
      writer_thread.second->StopThread();
    }

    writer_threads_.clear();
    writers_closed_ = mark_closed;
  }
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
    Env* env, bool finalized) {
  DCHECK(!run_dir_.empty());

  experimental::SnapshotMetadataRecord metadata;
  metadata.set_creation_timestamp(EnvTime::NowMicros());
  metadata.set_graph_hash(absl::StrFormat("%d", dataset()->hash_));
  metadata.set_run_id(absl::StrFormat("%d", run_id_));
  metadata.set_version(kSnapshotFileFormatVersion);
  for (const auto& output_dtype : dataset()->output_dtypes()) {
    metadata.add_dtype(output_dtype);
  }
  metadata.set_finalized(finalized);
  tstring hash_directory = HashDirectory(dataset()->path_, dataset()->hash_);

  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(hash_directory));
  return snapshot_util::WriteMetadataFile(hash_directory, &metadata);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_));

  return dataset()->input_->MakeIterator(
      ctx, this, strings::StrCat(prefix(), "::WriterIterator"), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
    std::vector<Tensor>* tensors, int64* shard_index) {
  std::vector<Tensor> output_tensors;

  // Run the shard function
  TF_RETURN_IF_ERROR(
      instantiated_shard_func_->RunInstantiated(*tensors, &output_tensors));

  if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
      output_tensors[0].NumElements() != 1) {
    return errors::InvalidArgument("`shard_func` must return a scalar int64.");
  }

  // Create writable files if we see an index bigger than our current files.
  *shard_index = output_tensors[0].flat<int64>()(0);
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  *end_of_sequence = false;
  WriterThread* current_writer_thread;

  {
    std::vector<Tensor> output_tensors;
    mutex_lock l(mu_);

    // We initialize late here because restoring from checkpoint comes after
    // the Initialize call. We cannot initialize within Initialize() because
    // we cannot determine whether we should overwrite an existing metadata
    // file or not before `RestoreInternal` is potentially called.
    if (run_dir_.empty()) {
      run_id_ = random::New64();

      // Creates the run directory.
      run_dir_ = RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_),
                              run_id_);
      TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
      TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
    }

    // Writers have either encountered an error or are closed.
    if (!writer_status_.ok() || writers_closed_) {
      *end_of_sequence = true;
      return writer_status_;
    }

    TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence));

    // Finalize metadata file when we are at the end of the iterator.
    if (*end_of_sequence) {
      StopWriterThreads(/*mark_closed=*/true);
      TF_RETURN_IF_ERROR(writer_status_);
      return WriteMetadataFile(ctx->env(), /*finalized=*/true);
    }

    int64 shard_index = 0;
    TF_RETURN_IF_ERROR(GetShardIndex(out_tensors, &shard_index));

    // If the index does not exist, we will start a new thread.
    if (writer_threads_.count(shard_index) == 0) {
      const tstring snapshot_shard_directory =
          SnapshotShardDirectory(run_dir_, shard_index);
      auto thread_data = std::make_unique<WriterThread>(
          this, ctx->env(), shard_index, snapshot_shard_directory,
          current_checkpoint_id_);
      writer_threads_.insert({shard_index, std::move(thread_data)});
    }
    current_writer_thread = writer_threads_[shard_index].get();
  }

  current_writer_thread->EnqueueTensors(*out_tensors);
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kRunId), static_cast<int64>(run_id_)));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCurrentCheckpointId),
                          static_cast<int64>(current_checkpoint_id_)));

  StopWriterThreads(/*mark_closed=*/false);
  writer_threads_.clear();
  current_checkpoint_id_++;

  return SaveInput(ctx, writer, input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  mutex_lock l(mu_);
  int64 run_id_signed;
  int64 current_checkpoint_id;

  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_signed));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointId),
                                        &current_checkpoint_id));

  run_id_ = static_cast<uint64>(run_id_signed);
  run_dir_ =
      RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_), run_id_);
  current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);

  return RestoreInput(ctx, reader, input_impl_);
}
SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Passthrough(
    const Params& params)
    : DatasetIterator<Dataset>(params) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Initialize(
    IteratorContext* ctx) {
  return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return SaveInput(ctx, writer, input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return RestoreInput(ctx, reader, input_impl_);
}
SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  FunctionMetadata::Params reader_params;
  FunctionMetadata::Params shard_params;

  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));

  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
                                               &reader_func_metadata_));
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
                                               &shard_func_metadata_));
}

void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                      DatasetBase** output) {
  tstring path;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));

  // Computes the hash of the preceding items in the graph.
  uint64 graph_hash;
  GraphDef graph_def;
  SerializationContext::Params params;
  std::vector<std::pair<string, Tensor>> input_list;
  params.input_list = &input_list;
  params.external_state_policy =
      SerializationContext::ExternalStatePolicy::kIgnore;
  OP_REQUIRES_OK(
      ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));
  OP_REQUIRES_OK(ctx, HashGraph(graph_def, &graph_hash));

  std::unique_ptr<CapturedFunction> reader_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, reader_func_metadata_,
                                          kReaderFuncOtherArgs, &reader_func));
  std::unique_ptr<CapturedFunction> shard_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, shard_func_metadata_,
                                          kShardFuncOtherArgs, &shard_func));

  *output = new SnapshotDatasetV2Op::Dataset(
      ctx, input, graph_hash, path, compression_, std::move(reader_func),
      std::move(shard_func));
}

namespace {
REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetV2").Device(DEVICE_CPU),
                        SnapshotDatasetV2Op);
}  // namespace

// ==== Legacy Snapshot Implementation (Deprecated) ====

namespace {

// Defaults to 10 GiB per shard.
@@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace data {
namespace experimental {

const int64 kSnapshotFileFormatVersion = 1;

class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
 public:
  static constexpr const char* const kDatasetType = "Snapshot";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
  static constexpr const char* const kCompression = "compression";
  static constexpr const char* const kReaderFunc = "reader_func";
  static constexpr const char* const kShardFunc = "shard_func";
  static constexpr const char* const kReaderFuncOtherArgs =
      "reader_func_other_args";
  static constexpr const char* const kShardFuncOtherArgs =
      "shard_func_other_args";
  static constexpr const char* const kReaderFuncTarguments =
      "Treader_func_args";
  static constexpr const char* const kShardFuncTarguments = "Tshard_func_args";

  explicit SnapshotDatasetV2Op(OpKernelConstruction* ctx);

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override;

 private:
  class Dataset;

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;

  std::string compression_;

  std::shared_ptr<FunctionMetadata> reader_func_metadata_;
  std::shared_ptr<FunctionMetadata> shard_func_metadata_;
};

}  // namespace experimental
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
@@ -899,6 +899,26 @@ REGISTER_OP("SnapshotDataset")
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("SnapshotDatasetV2")
    .Input("input_dataset: variant")
    .Input("path: string")
    .Input("reader_func_other_args: Treader_func_args")
    .Input("shard_func_other_args: Tshard_func_args")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("compression: string = ''")
    .Attr("reader_func: func")
    .Attr("shard_func: func")
    .Attr("Treader_func_args: list(type) >= 0")
    .Attr("Tshard_func_args: list(type) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("SqlDataset")
    .Input("driver_name: string")
    .Input("data_source_name: string")
@@ -75,6 +75,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@sample_from_datasets
@@scan
@@shuffle_and_repeat
@@snapshot
@@take_while
@@to_variant
@@unbatch
@@ -128,6 +129,7 @@ from tensorflow.python.data.experimental.ops.readers import SqlDataset
from tensorflow.python.data.experimental.ops.resampling import rejection_resample
from tensorflow.python.data.experimental.ops.scan_ops import scan
from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.python.data.experimental.ops.snapshot import snapshot
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import bytes_produced_stats
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
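
With the export above, the same transformation should also be reachable through the public experimental namespace. A short sketch under that assumption (the exact alias and signature of `tf.data.experimental.snapshot` depend on the release):

  import tensorflow as tf

  # `tf.data.experimental.snapshot` is assumed here to be the exported alias of
  # the `snapshot` transformation imported above.
  dataset = tf.data.Dataset.range(10)
  dataset = dataset.apply(tf.data.experimental.snapshot("/tmp/snapshot_dir"))
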
@@ -753,9 +753,12 @@ tf_py_test(
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:string_ops",
        "//tensorflow/python/data/experimental/ops:readers",
        "//tensorflow/python/data/experimental/ops:snapshot",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:readers",
        "//tensorflow/python/data/util:nest",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -199,7 +199,9 @@ tf_py_test(
    deps = [
        ":dataset_serialization_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -34,6 +34,97 @@ class SnapshotDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):

  def _build_snapshot_dataset(self, repeat=False):

    def ds_fn():
      self._snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
      if not os.path.exists(self._snapshot_dir):
        os.mkdir(self._snapshot_dir)

      dataset = dataset_ops.Dataset.range(100)
      dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
      if repeat:
        dataset = dataset.repeat(2)
      return dataset

    return ds_fn

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointBeforeEpochEndNoRepeat(self):
    ds_fn = self._build_snapshot_dataset(repeat=False)
    outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(50))
    outputs.extend(
        self.gen_outputs(ds_fn, [], 50, ckpt_saved=True, verify_exhausted=True))
    self.assertSequenceEqual(outputs, range(100))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointBeforeOneEpochWithReading(self):
    ds_fn = self._build_snapshot_dataset(repeat=True)

    # Generate 50 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(50)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    t = self.gen_outputs(ds_fn, [], 150, ckpt_saved=True, verify_exhausted=True)
    outputs.extend(t)
    self.assertSequenceEqual(
        outputs,
        list(range(50)) + list(range(50, 100)) + list(range(100)))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointBeforeOneEpochThenRunAFewSteps(self):
    ds_fn = self._build_snapshot_dataset(repeat=False)
    outputs = self.gen_outputs(
        ds_fn, [10], 20, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, range(20))

    outputs = outputs[:10]
    outputs.extend(
        self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True))
    self.assertSequenceEqual(outputs, range(100))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointAfterOneEpoch(self):
    ds_fn = self._build_snapshot_dataset(repeat=True)

    # Generate 110 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 110, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(100)) + list(range(10)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
    outputs.extend(t)
    self.assertSequenceEqual(
        outputs,
        list(range(100)) + list(range(10)) + list(range(10, 100)))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointAfterOneEpochRunFewSteps(self):
    ds_fn = self._build_snapshot_dataset(repeat=True)

    # Generate 120 entries from iterator and save checkpoint at 110.
    outputs = self.gen_outputs(
        ds_fn, [110], 120, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(100)) + list(range(20)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs = outputs[:110]
    t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
    outputs.extend(t)
    self.assertSequenceEqual(
        outputs,
        list(range(100)) + list(range(10)) + list(range(10, 100)))


class LegacySnapshotDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):

  def _build_snapshot_dataset(self,
                              num_threads=1,
                              repeat=False,
@@ -17,11 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os
import shutil
import time

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import snapshot
@ -40,6 +42,285 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
|
||||
def setUp(self):
|
||||
super(SnapshotDatasetTest, self).setUp()
|
||||
tmpdir = self.get_temp_dir()
|
||||
tmpdir = os.path.join(tmpdir, "snapshot")
|
||||
os.mkdir(tmpdir)
|
||||
self._snapshot_dir = tmpdir
|
||||
|
||||
def tearDown(self):
|
||||
super(SnapshotDatasetTest, self).tearDown()
|
||||
shutil.rmtree(self._snapshot_dir)
|
||||
|
||||
def createTFRecords(self, num_files=10, num_records=100):
|
||||
self._num_files = num_files
|
||||
self._num_records = num_records
|
||||
self._test_filenames = self._createFiles()
|
||||
|
||||
def removeTFRecords(self):
|
||||
for filename in self._test_filenames:
|
||||
os.remove(filename)
|
||||
self._test_filenames = []
|
||||
self._num_files = None
|
||||
self._num_records = None
|
||||
|
||||
def assertDatasetProducesSet(self, dataset, expected):
|
||||
actual = []
|
||||
next_fn = self.getNext(dataset)
|
||||
for _ in range(len(expected)):
|
||||
elem = self.evaluate(next_fn())
|
||||
actual.append(elem)
|
||||
self.assertCountEqual(actual, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_fn())
|
||||
|
||||
def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
|
||||
num_runs_per_fingerprint,
|
||||
num_snapshot_shards_per_run):
|
||||
dirlist_raw = os.listdir(directory)
|
||||
dirlist = []
|
||||
|
||||
# Ignore the graphdef pbtxts we write for debugging purposes.
|
||||
for i in range(len(dirlist_raw)):
|
||||
if not dirlist_raw[i].endswith("-graph.pbtxt"):
|
||||
dirlist.append(dirlist_raw[i])
|
||||
|
||||
self.assertLen(dirlist, num_fingerprints)
|
||||
|
||||
for i in range(num_fingerprints):
|
||||
fingerprint_dir = os.path.join(directory, dirlist[i])
|
||||
fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
|
||||
self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1)
|
||||
self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint],
|
||||
"snapshot.metadata")
|
||||
|
||||
for j in range(num_runs_per_fingerprint):
|
||||
run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
|
||||
run_dirlist = sorted(os.listdir(run_dir))
|
||||
self.assertLen(run_dirlist, num_snapshot_shards_per_run)
|
||||
|
||||
file_counter = 0
|
||||
for filename in run_dirlist:
|
||||
self.assertEqual(filename, "%08d.shard" % file_counter)
|
||||
file_counter += 1
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCreateSnapshotDataset(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
|
||||
dataset.apply(snapshot.snapshot(self._snapshot_dir))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadSnapshotDatasetDefault(self):
|
||||
self.createTFRecords()
|
||||
filenames = self._test_filenames
|
||||
expected = [
|
||||
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||
for f in range(0, 10)
|
||||
for r in range(0, 100)
|
||||
]
|
||||
|
||||
dataset = core_readers._TFRecordDataset(filenames)
|
||||
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
|
||||
self.assertDatasetProduces(dataset, expected)
|
||||
self.assertSnapshotDirectoryContains(
|
||||
self._snapshot_dir,
|
||||
num_fingerprints=1,
|
||||
num_runs_per_fingerprint=1,
|
||||
num_snapshot_shards_per_run=multiprocessing.cpu_count())
|
||||
|
||||
self.removeTFRecords()
|
||||
dataset2 = core_readers._TFRecordDataset(filenames)
|
||||
dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
|
||||
self.assertDatasetProduces(dataset2, expected)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadSnapshotDatasetCustomShardFn(self):
|
||||
self.createTFRecords()
|
||||
filenames = self._test_filenames
|
||||
expected = [
|
||||
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||
for f in range(0, 10)
|
||||
for r in range(0, 100)
|
||||
]
|
||||
|
||||
dataset = core_readers._TFRecordDataset(filenames)
|
||||
dataset = dataset.apply(
|
||||
snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: np.int64(0)))
|
||||
self.assertDatasetProduces(dataset, expected)
|
||||
self.assertSnapshotDirectoryContains(
|
||||
self._snapshot_dir,
|
||||
num_fingerprints=1,
|
||||
num_runs_per_fingerprint=1,
|
||||
num_snapshot_shards_per_run=1)
|
||||
|
||||
self.removeTFRecords()
|
||||
dataset2 = core_readers._TFRecordDataset(filenames)
|
||||
dataset2 = dataset2.apply(
|
||||
snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0))
|
||||
self.assertDatasetProduces(dataset2, expected)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadSnapshotDatasetCustomReaderFn(self):
|
||||
self.createTFRecords()
|
||||
filenames = self._test_filenames
|
||||
expected = [
|
||||
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||
for f in range(0, 10)
|
||||
for r in range(0, 100)
|
||||
]
|
||||
|
||||
dataset = core_readers._TFRecordDataset(filenames)
|
||||
dataset = dataset.apply(
|
||||
snapshot.snapshot(
|
||||
self._snapshot_dir,
|
||||
reader_func=(
|
||||
lambda ds: ds.interleave( # pylint:disable=g-long-lambda
|
||||
lambda x: x,
|
||||
cycle_length=4,
|
||||
num_parallel_calls=4))))
|
||||
self.assertDatasetProduces(dataset, expected)
|
||||
self.assertSnapshotDirectoryContains(
|
||||
self._snapshot_dir,
|
||||
num_fingerprints=1,
|
||||
num_runs_per_fingerprint=1,
|
||||
num_snapshot_shards_per_run=multiprocessing.cpu_count())
|
||||
|
||||
    self.removeTFRecords()
    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(
            self._snapshot_dir,
            reader_func=(
                lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                    lambda x: x,
                    cycle_length=4,
                    num_parallel_calls=4))))
    self.assertDatasetProducesSet(dataset2, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testSnapshotDatasetInvalidShardFn(self):
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaises(TypeError):
      dataset = dataset.apply(
          snapshot.snapshot(
              self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
      next_fn = self.getNext(dataset)
      self.evaluate(next_fn())

  @combinations.generate(test_base.default_test_combinations())
  def testSnapshotDatasetInvalidReaderFn(self):
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaises(TypeError):
      dataset = dataset.apply(
          snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
      next_fn = self.getNext(dataset)
      self.evaluate(next_fn())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetSimple(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset, list(range(1000)))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetMultipleFingerprints(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset1, list(range(1000)))

    dataset2 = dataset_ops.Dataset.range(2000)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset2, list(range(2000)))

    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=2,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset1, list(range(1000)))
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset2, list(range(1000)))

    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
    next1 = self.getNext(dataset1)
    for i in range(500):
      self.assertEqual(i, self.evaluate(next1()))

    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    next2 = self.getNext(dataset2)
    for i in range(500):
      self.assertEqual(i, self.evaluate(next2()))

    for i in range(500, 1000):
      self.assertEqual(i, self.evaluate(next1()))
      self.assertEqual(i, self.evaluate(next2()))

    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=2,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotCustomShardFunction(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.enumerate()
    dataset = dataset.apply(
        snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2))
    dataset = dataset.map(lambda _, elem: elem)
    self.assertDatasetProduces(dataset, list(range(1000)))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=2)

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetWithTuples(self):
    dataset1 = dataset_ops.Dataset.range(0, 1000)
    dataset2 = dataset_ops.Dataset.range(1000, 2000)
    dataset3 = dataset_ops.Dataset.range(2000, 3000)
    dataset4 = dataset_ops.Dataset.range(3000, 4000)

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4))
    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
    next1 = self.getNext(dataset)
    for i in range(0, 1000):
      self.assertEqual((i, i + 1000, i + 2000, i + 3000),
                       self.evaluate(next1()))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())
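The `assertSnapshotDirectoryContains` checks above count artifacts at three levels: fingerprint directories, runs per fingerprint, and shard files per run. As a rough illustration of what such a check can look like (the exact `<snapshot_dir>/<fingerprint>/<run>/<shard-file>` nesting and the helper name below are assumptions made for this sketch, not taken from this commit):

```python
import os


def count_snapshot_artifacts(snapshot_dir):
  """Illustrative sketch: counts fingerprints, runs, and shard files.

  Assumes a <snapshot_dir>/<fingerprint>/<run>/<shard-file> layout, which is
  an assumption made for this example.
  """
  counts = {}
  for fingerprint in os.listdir(snapshot_dir):
    fp_dir = os.path.join(snapshot_dir, fingerprint)
    if not os.path.isdir(fp_dir):
      continue
    runs = {}
    for run in os.listdir(fp_dir):
      run_dir = os.path.join(fp_dir, run)
      if os.path.isdir(run_dir):
        runs[run] = len(os.listdir(run_dir))  # number of shard files
    counts[fingerprint] = runs
  return counts
```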


class LegacySnapshotDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):

  def setUp(self):
    super(LegacySnapshotDatasetTest, self).setUp()
    self.removeTFRecords()
    tmpdir = self.get_temp_dir()
    tmpdir = os.path.join(tmpdir, "snapshot")
@@ -47,7 +328,7 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    self.snapshot_dir = tmpdir

  def tearDown(self):
    super(SnapshotDatasetTest, self).tearDown()
    super(LegacySnapshotDatasetTest, self).tearDown()
    shutil.rmtree(self.snapshot_dir)

  def removeTFRecords(self):
@@ -63,8 +344,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
  def makeSnapshotDirectory(self):
    return self.snapshot_dir

  def assertSnapshotDirectoryContains(
      self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files):
  def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
                                      num_runs_per_fp, num_snapshot_files):
    dirlist_raw = os.listdir(directory)
    dirlist = []

@@ -465,8 +746,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
          ]),
          combinations.combine(threads=2, size=[1, 2]) +
          combinations.combine(threads=8, size=[1, 4, 8]))))
  def testReadSnapshotBackAfterMultiThreadedWrite(
      self, compression, threads, size):
  def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
                                                  size):
    self.setUpTFRecord()
    filenames = self.test_filenames

@@ -17,12 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
@@ -99,6 +103,8 @@ class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
@@ -186,3 +192,165 @@ def legacy_snapshot(path,
        snapshot_name=snapshot_name)

  return _apply_fn


class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset that allows saving and re-use of already processed data."""

  def __init__(self,
               input_dataset,
               path,
               shard_func,
               compression=None,
               reader_func=None,
               pending_snapshot_expiry_seconds=None,
               use_legacy_function=False):

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._input_dataset = input_dataset
    self._path = path
    self._compression = compression

    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        self._transformation_name() + ".reader_func",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(input_dataset.element_spec)),
        use_legacy_function=use_legacy_function)
    self._shard_func = dataset_ops.StructuredFunctionWrapper(
        shard_func,
        self._transformation_name() + ".shard_func",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)

    if ((not self._shard_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int32))) and
        (not self._shard_func.output_structure.is_compatible_with(
            tensor_spec.TensorSpec([], dtypes.int64)))):
      raise TypeError(
          "shard_func must return a 0-dimension tensor containing an int.")

    variant_tensor = ged_ops.snapshot_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        path,
        self._reader_func.function.captured_inputs,
        self._shard_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        shard_func=self._shard_func.function,
        **self._flat_structure)
    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._reader_func, self._shard_func]

  def _transformation_name(self):
    return "Dataset.snapshot"


@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading off disk storage and network
  bandwidth for freeing up more valuable CPU resources and accelerator compute
  time.

  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read from and written to by passing in
  user-defined functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func` could
  be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot(
      shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user specified function that accepts a single argument:
  (1) a Dataset of Datasets, each representing a "split" of elements of the
  original dataset. The cardinality of the input dataset matches the
  number of the shards specified in the `shard_func` (see above). The function
  should return a Dataset of elements of the original dataset.

  Users may want to specify this function to control how snapshot files should
  be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define. This
  function enables both dataset shuffling and parallel reading of datasets:

  ```python
  def user_reader_func(datasets):
    # shuffle the datasets splits
    datasets = datasets.shuffle(NUM_CORES)
    # read datasets in parallel and interleave their elements
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot(
      reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

  Args:
    path: Required. A directory to use for storing / loading the snapshot to /
      from.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
      Defaults to AUTO, which attempts to pick an appropriate compression
      algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from snapshot
      shards.
    shard_func: Optional. A function to control how to shard data when writing
      a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Actual dataset transformation."""
    if shard_func is None:
      dataset = dataset.enumerate()
      dataset = _SnapshotDataset(
          input_dataset=dataset,
          path=path,
          compression=compression,
          reader_func=reader_func,
          # This will not do the right thing where the graph is built on a
          # different machine than the executor (e.g. Cloud TPUs).
          shard_func=lambda index, _: index % multiprocessing.cpu_count())
      return dataset.map(lambda _, elem: elem)
    else:
      return _SnapshotDataset(
          input_dataset=dataset,
          path=path,
          compression=compression,
          reader_func=reader_func,
          shard_func=shard_func)

  return _apply_fn
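Putting the docstring pieces together, here is a minimal end-to-end sketch of writing a snapshot on the first run and reading it back with a custom `reader_func` on later runs (the path, shard count, and preprocessing pipeline are illustrative):

```python
import tensorflow as tf

NUM_SHARDS = 8
SNAPSHOT_DIR = "/tmp/my_snapshot"  # illustrative path


def make_dataset():
  # Stand-in for an expensive preprocessing pipeline.
  ds = tf.data.Dataset.range(10000).map(lambda x: x * 2)
  ds = ds.enumerate()
  ds = ds.apply(tf.data.experimental.snapshot(
      SNAPSHOT_DIR,
      compression="GZIP",
      shard_func=lambda index, _: index % NUM_SHARDS,
      reader_func=lambda datasets: datasets.shuffle(NUM_SHARDS).interleave(
          lambda x: x,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)))
  return ds.map(lambda _, elem: elem)


# First run: the range/map pipeline executes and its output is persisted.
# Subsequent runs of the same pipeline read back from SNAPSHOT_DIR instead,
# shuffling and interleaving the shard splits via reader_func.
for value in make_dataset().take(3):
  print(value.numpy())
```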
@@ -220,6 +220,10 @@ tf_module {
    name: "shuffle_and_repeat"
    argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "snapshot"
    argspec: "args=[\'path\', \'compression\', \'reader_func\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\'], "
  }
  member_method {
    name: "take_while"
    argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"

@@ -3992,6 +3992,10 @@ tf_module {
    name: "SnapshotDataset"
    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
  }
  member_method {
    name: "SnapshotDatasetV2"
    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
  }
  member_method {
    name: "SobolSample"
    argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "

@@ -188,6 +188,10 @@ tf_module {
    name: "shuffle_and_repeat"
    argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "snapshot"
    argspec: "args=[\'path\', \'compression\', \'reader_func\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\'], "
  }
  member_method {
    name: "take_while"
    argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"

@@ -3992,6 +3992,10 @@ tf_module {
    name: "SnapshotDataset"
    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
  }
  member_method {
    name: "SnapshotDatasetV2"
    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
  }
  member_method {
    name: "SobolSample"
    argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "