Add implementation for snapshot dataset v2.
PiperOrigin-RevId: 314258033
Change-Id: I6151fdc646a297090de6eeeb3254a556ae9d13bc

parent 70387ab55b
commit a3f393bc95
@@ -0,0 +1,41 @@
op {
  graph_op_name: "SnapshotDatasetV2"
  visibility: HIDDEN
  in_arg {
    name: "input_dataset"
    description: <<END
A variant tensor representing the input dataset.
END
  }
  in_arg {
    name: "path"
    description: <<END
The path we should write snapshots to / read snapshots from.
END
  }
  attr {
    name: "compression"
    description: <<END
The type of compression to be applied to the saved snapshot files.
END
  }
  attr {
    name: "reader_func"
    description: <<END
Optional. A function to control how to read data from snapshot shards.
END
  }
  attr {
    name: "shard_func"
    description: <<END
Optional. A function to control how to shard data when writing a snapshot.
END
  }
  summary: "Creates a dataset that will write to / read from a snapshot."
  description: <<END
This dataset attempts to determine whether a valid snapshot exists at the
`snapshot_path`, and reads from the snapshot in lieu of using `input_dataset`.
If not, it will run the preprocessing pipeline as usual, and write out a
snapshot of the data processed for future use.
END
}
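For orientation, the op described above is hidden and is normally reached through the tf.data.experimental.snapshot transformation that this change exports on the Python side. A minimal usage sketch, assuming TF 2.x eager execution and a writable /tmp/snapshot directory (both are illustrative assumptions, not part of this change):

import tensorflow as tf

# First run: the range() pipeline executes and its output is written out as a
# snapshot under /tmp/snapshot. Subsequent runs: the snapshot is read back and
# the preprocessing is skipped.
dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(tf.data.experimental.snapshot("/tmp/snapshot"))

for element in dataset:
  print(element.numpy())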
@@ -544,6 +544,7 @@ cc_library(
tf_kernel_library(
    name = "snapshot_dataset_op",
    srcs = ["snapshot_dataset_op.cc"],
    hdrs = ["snapshot_dataset_op.h"],
    deps = [
        ":snapshot_util",
        "//tensorflow/core:core_cpu_internal",
@@ -552,10 +553,16 @@ tf_kernel_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:op_requires",
        "//tensorflow/core/grappler:graph_view",
        "//tensorflow/core/kernels/data:captured_function",
        "//tensorflow/core/kernels/data:dataset_utils",
        "//tensorflow/core/kernels/data:name_utils",
        "//tensorflow/core/platform:platform_port",
        "//tensorflow/core/profiler/lib:traceme",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
)
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h"

#include <random>

#include "absl/time/clock.h"
@@ -60,6 +62,875 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace experimental {

// ==== Snapshot Implementation ====

namespace {

/* The current snapshot on-disk layout is as follows:
 *   /user/specified/path/
 *     - graphhash1/
 *       - snapshot.metadata  // metadata file
 *       - run1/
 *         - 00000000.shard/  // shard index
 *           // new checkpoint files are created on all threads at once, either
 *           // when a file gets too big, or when a TF checkpoint happens.
 *           - 00000000.snapshot  // checkpoint file 0
 *           - 00000001.snapshot  // checkpoint file 1
 *           - ...
 *         - 00000001.shard/
 *           - 00000000.snapshot
 *           - 00000001.snapshot
 *           - ...
 *         - 00000002.shard/
 *           - 00000000.snapshot
 *           - 00000001.snapshot
 *           - ...
 *         ...
 *       - run2/
 *         ...
 *     - graphhash2/
 *       ...
 *     - graphhash3/
 *       ...
 */

constexpr const char* const kShardDirectorySuffix = ".shard";

inline tstring HashDirectory(const tstring& path, const uint64 hash) {
  return io::JoinPath(path, absl::StrFormat("%d", hash));
}

inline tstring RunDirectory(const tstring& hash_directory,
                            const uint64 run_id) {
  return io::JoinPath(hash_directory, absl::StrFormat("%d", run_id));
}

inline tstring SnapshotShardDirectory(const tstring& run_directory,
                                      const int64 snapshot_index) {
  return io::JoinPath(run_directory, absl::StrFormat("%08d%s", snapshot_index,
                                                     kShardDirectorySuffix));
}

}  // namespace

class SnapshotDatasetV2Op::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
          const std::string& path, const std::string& compression,
          std::unique_ptr<CapturedFunction> reader_func,
          std::unique_ptr<CapturedFunction> shard_func);

  ~Dataset() override;

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  const DataTypeVector& output_dtypes() const override;

  const std::vector<PartialTensorShape>& output_shapes() const override;

  string DebugString() const override;

  int64 Cardinality() const override;

  Status CheckExternalState() const override;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  const DatasetBase* input_;
  const uint64 hash_;
  const tstring path_;
  const std::string compression_;

  std::unique_ptr<CapturedFunction> reader_func_;
  std::unique_ptr<CapturedFunction> shard_func_;

  class Iterator;
};

class SnapshotDatasetV2Op::Dataset::Iterator : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorMode = "iterator_mode";
  static constexpr const char* const kIndex = "index";
  static constexpr const char* const kGraphHashDirectory =
      "graph_hash_directory";

  explicit Iterator(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  Status InitializeIterator(IteratorContext* ctx, IteratorStateReader* reader);

  int64 index_ TF_GUARDED_BY(mu_);
  std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  snapshot_util::Mode mode_ TF_GUARDED_BY(mu_);
  const std::string hash_dir_;

  mutex mu_;

  class Reader;
  class Writer;
  class Passthrough;
};

class SnapshotDatasetV2Op::Dataset::Iterator::Reader
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Reader";

  explicit Reader(const Params& params, int64 start_index);

  ~Reader() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  const int64 start_index_;

  mutex mu_;

  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  DatasetBase* input_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
      TF_GUARDED_BY(mu_);
};

class SnapshotDatasetV2Op::Dataset::Iterator::Writer
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Writer";
  static constexpr const char* const kRunId = "run_id";
  static constexpr const char* const kCurrentCheckpointId =
      "current_checkpoint_id";

  explicit Writer(const Params& params);

  ~Writer() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  struct BufferElement {
    std::vector<Tensor> value;
    bool end_of_sequence = false;
  };

  class WriterThread {
   public:
    explicit WriterThread(Writer* writer_iterator, Env* env, int64 file_index,
                          const tstring& shard_directory,
                          uint64 current_checkpoint_id) {
      thread_ = absl::WrapUnique(env->StartThread(
          ThreadOptions(), absl::StrCat("snapshot_writer_thread_", file_index),
          [this, writer_iterator, env, shard_directory, current_checkpoint_id] {
            RunWriterThread(writer_iterator, env, shard_directory,
                            current_checkpoint_id);
          }));
    }

    void EnqueueTensors(const std::vector<Tensor>& tensors)
        TF_LOCKS_EXCLUDED(deque_mu_) {
      // Copy the Tensor to the deque for writing.
      mutex_lock l(deque_mu_);
      BufferElement be;
      be.value = tensors;
      deque_.push_back(std::move(be));
    }

    void DequeueTensors(BufferElement* be) TF_LOCKS_EXCLUDED(deque_mu_) {
      mutex_lock l(deque_mu_);
      deque_mu_.Await(
          tensorflow::Condition(this, &WriterThread::DequeIsNotEmpty));
      *be = deque_.front();
      deque_.pop_front();
    }

    void StopThread() TF_LOCKS_EXCLUDED(deque_mu_) {
      mutex_lock l(deque_mu_);
      BufferElement be;
      be.end_of_sequence = true;
      deque_.push_back(std::move(be));
    }

    void RunWriterThread(Writer* writer_iterator, Env* env,
                         const tstring& shard_directory,
                         uint64 current_checkpoint_id) {
      Status s = WriterThreadFn(writer_iterator, env, shard_directory,
                                current_checkpoint_id);
      if (!s.ok()) {
        mutex_lock l(writer_iterator->mu_);
        writer_iterator->writer_status_ = s;
      }
    }

   private:
    bool DequeIsNotEmpty() TF_EXCLUSIVE_LOCKS_REQUIRED(deque_mu_) {
      return !deque_.empty();
    }

    Status WriterThreadFn(Writer* writer_iterator, Env* env,
                          const tstring& shard_directory,
                          uint64 current_checkpoint_id) {
      std::unique_ptr<snapshot_util::Writer> writer;
      TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

      TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
          env,
          snapshot_util::GetCurrentCheckpointFile(shard_directory,
                                                  current_checkpoint_id),
          writer_iterator->dataset()->compression_, kSnapshotFileFormatVersion,
          writer_iterator->dataset()->output_dtypes(), &writer));

      while (true) {
        BufferElement be;
        DequeueTensors(&be);

        if (be.end_of_sequence) {
          TF_RETURN_IF_ERROR(writer->Close());
          break;
        }

        TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
      }
      return Status::OK();
    }

    // If both the writer `mu_` and this `deque_mu_` need to be acquired, the
    // writer `mu_` must be acquired first.
    mutex deque_mu_;

    std::deque<BufferElement> deque_ TF_GUARDED_BY(deque_mu_);

    // This has to be last. During destruction, we need to make sure that
    // thread_ is destroyed first as the thread destructor blocks on thread
    // completion. If there are other member variables after this, they may get
    // destroyed first before the thread finishes, potentially causing the
    // thread to access invalid memory.
    std::unique_ptr<Thread> thread_;
  };
  Status GetShardIndex(std::vector<Tensor>* tensors, int64* shard_index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status WriteMetadataFile(Env* env, bool finalized)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void StopWriterThreads(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutex mu_;
  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  absl::flat_hash_map<int64, std::unique_ptr<WriterThread>> writer_threads_
      TF_GUARDED_BY(mu_);
  Status writer_status_ TF_GUARDED_BY(mu_);
  bool writers_closed_ TF_GUARDED_BY(mu_);

  uint64 run_id_ TF_GUARDED_BY(mu_);
  tstring run_dir_ TF_GUARDED_BY(mu_);

  // Stores the ID of the current checkpoint .snapshot file being read. See top
  // of this file for the directory layout.
  uint64 current_checkpoint_id_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_shard_func_
      TF_GUARDED_BY(mu_);
};

class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Passthrough";

  explicit Passthrough(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  std::unique_ptr<IteratorBase> input_impl_;
};

SnapshotDatasetV2Op::Dataset::Dataset(
    OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
    const std::string& path, const std::string& compression,
    std::unique_ptr<CapturedFunction> reader_func,
    std::unique_ptr<CapturedFunction> shard_func)
    : DatasetBase(DatasetContext(ctx)),
      input_(input),
      hash_(hash),
      path_(path),
      compression_(compression),
      reader_func_(std::move(reader_func)),
      shard_func_(std::move(shard_func)) {
  input_->Ref();
}

SnapshotDatasetV2Op::Dataset::~Dataset() { input_->Unref(); }

std::unique_ptr<IteratorBase>
SnapshotDatasetV2Op::Dataset::MakeIteratorInternal(const string& prefix) const {
  return absl::make_unique<Iterator>(
      Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
}

const DataTypeVector& SnapshotDatasetV2Op::Dataset::output_dtypes() const {
  return input_->output_dtypes();
}

const std::vector<PartialTensorShape>&
SnapshotDatasetV2Op::Dataset::output_shapes() const {
  return input_->output_shapes();
}

string SnapshotDatasetV2Op::Dataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}

int64 SnapshotDatasetV2Op::Dataset::Cardinality() const {
  return input_->Cardinality();
}

Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const {
  return input_->CheckExternalState();
}

Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
    SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const {
  Node* input_graph_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

  Node* path = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(path_, &path));

  std::vector<Node*> reader_func_other_args;
  DataTypeVector reader_func_other_args_types;
  TF_RETURN_IF_ERROR(reader_func_->AddToGraph(ctx, b, &reader_func_other_args,
                                              &reader_func_other_args_types));

  std::vector<Node*> shard_func_other_args;
  DataTypeVector shard_func_other_args_types;
  TF_RETURN_IF_ERROR(shard_func_->AddToGraph(ctx, b, &shard_func_other_args,
                                             &shard_func_other_args_types));

  AttrValue compression_attr;
  b->BuildAttrValue(compression_, &compression_attr);

  AttrValue reader_func_attr;
  b->BuildAttrValue(reader_func_->func(), &reader_func_attr);

  AttrValue shard_func_attr;
  b->BuildAttrValue(shard_func_->func(), &shard_func_attr);

  AttrValue reader_func_arguments_types_attr;
  b->BuildAttrValue(reader_func_other_args_types,
                    &reader_func_arguments_types_attr);

  AttrValue shard_func_arguments_types_attr;
  b->BuildAttrValue(shard_func_other_args_types,
                    &shard_func_arguments_types_attr);

  return b->AddDataset(
      this,
      /*inputs=*/
      {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
      /*list_inputs=*/
      {std::make_pair(2, reader_func_other_args),
       std::make_pair(3, shard_func_other_args)},
      /*attrs=*/
      {{kCompression, compression_attr},
       {kReaderFunc, reader_func_attr},
       {kShardFunc, shard_func_attr},
       {kReaderFuncTarguments, reader_func_arguments_types_attr},
       {kShardFuncTarguments, shard_func_arguments_types_attr}},
      output);
}

SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
    : DatasetIterator<Dataset>(params),
      index_(0),
      hash_dir_(HashDirectory(dataset()->path_, dataset()->hash_)) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
    IteratorContext* ctx) {
  return ctx->env()->RecursivelyCreateDir(hash_dir_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  if (iterator_ != nullptr) {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIteratorMode),
                                           static_cast<int64>(mode_)));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name(kGraphHashDirectory), hash_dir_));
  }

  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  mutex_lock l(mu_);

  if (reader->Contains(full_name(kIteratorMode))) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader));
    return RestoreInput(ctx, reader, iterator_);
  }

  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  if (iterator_ == nullptr) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
  }
  // TODO(b/154341936): Explicitly stopping and starting this iterator
  // should not be necessary, but the additional
  // `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
  // `iterator_` when it was created prevents the model from identifying this
  // iterator as the output of `iterator_`.
  RecordStop(ctx);
  Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
  index_++;
  RecordStart(ctx);
  return s;
}

Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
    IteratorContext* ctx, IteratorStateReader* reader)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (reader != nullptr) {
    // Check whether the computed hash directory is the same.
    tstring hash_dir;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kGraphHashDirectory), &hash_dir));
    if (hash_dir != hash_dir_) {
      return errors::DataLoss(
          "Dataset has changed while restoring from the checkpoint. Old hash "
          "directory: ",
          hash_dir, "; new hash directory: ", hash_dir_);
    }

    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(
        snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));
    if (!file_exists) {
      return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                              " does not exist any more.");
    }

    int64 iterator_mode;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kIteratorMode), &iterator_mode));
    mode_ = snapshot_util::Mode(iterator_mode);

    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
  } else {
    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(
        snapshot_util::ReadMetadataFile(hash_dir_, &metadata, &file_exists));

    // `pending_snapshot_expiry_seconds` is a legacy option where we would not
    // write snapshots that we think were still on-going. We decided that this
    // would not be necessary as a feature for SnapshotV2, and we would always
    // write a new snapshot regardless of whether someone else is currently
    // writing one. Setting this to 0 ensures that all previous snapshots
    // will be ignored and we will proceed to writing.
    TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
        /*mode_string=*/"", file_exists, &metadata,
        /*pending_snapshot_expiry_seconds=*/0, &mode_));
  }

  switch (mode_) {
    case snapshot_util::READER:
      iterator_ = absl::make_unique<Reader>(
          Reader::Params{dataset(),
                         absl::StrCat(prefix(), Reader::kIteratorName)},
          index_);
      break;
    case snapshot_util::WRITER:
      iterator_ = absl::make_unique<Writer>(Writer::Params{
          dataset(), absl::StrCat(prefix(), Writer::kIteratorName)});
      break;
    case snapshot_util::PASSTHROUGH:
      iterator_ = absl::make_unique<Passthrough>(Passthrough::Params{
          dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
      break;
  }
  return iterator_->Initialize(ctx);
}
SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
                                                       int64 start_index)
    : DatasetIterator<Dataset>(params), start_index_(start_index) {}

SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);

  TF_RETURN_IF_ERROR(
      dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));

  tstring hash_dir = HashDirectory(dataset()->path_, dataset()->hash_);
  bool metadata_file_exists;
  experimental::SnapshotMetadataRecord metadata;
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir, &metadata,
                                                     &metadata_file_exists));

  auto run_dir = io::JoinPath(hash_dir, metadata.run_id());

  std::vector<std::string> snapshot_shard_dirs;
  TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
      io::JoinPath(run_dir,
                   absl::StrFormat("%s%s", "*", kShardDirectorySuffix)),
      &snapshot_shard_dirs));
  std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

  DatasetBase* dataset_of_snapshot_files;
  TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
      ctx->env(), snapshot_shard_dirs, dataset()->compression_,
      kSnapshotFileFormatVersion, dataset()->output_dtypes(),
      dataset()->output_shapes(), start_index_, &dataset_of_snapshot_files));

  Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
                                                 &input_dataset_tensor));

  std::vector<Tensor> reader_input;
  std::vector<Tensor> reader_output;
  reader_input.push_back(std::move(input_dataset_tensor));

  TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
      ctx, std::move(reader_input), &reader_output));
  if (reader_output.size() != 1) {
    return errors::InvalidArgument(
        "reader_func returns more than one argument.");
  }
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));

  // We need to take a reference here as we will use the input_ and
  // its iterator.
  input_->Ref();

  return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}

// We do not need to checkpoint the reader as we are rebuilding the reader
// datasets from information that is already saved by the main iterator.
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return Status::OK();
}

SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)
    : DatasetIterator<Dataset>(params),
      writers_closed_(false),
      run_id_(0),
      current_checkpoint_id_(0) {}

SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
  mutex_lock l(mu_);
  StopWriterThreads(true);
}

void SnapshotDatasetV2Op::Dataset::Iterator::Writer::StopWriterThreads(
    bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!writers_closed_) {
    // Push the end of sequence signal to each of the threads to close files.
    for (auto& writer_thread : writer_threads_) {
      writer_thread.second->StopThread();
    }

    writer_threads_.clear();
    writers_closed_ = mark_closed;
  }
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
    Env* env, bool finalized) {
  DCHECK(!run_dir_.empty());

  experimental::SnapshotMetadataRecord metadata;
  metadata.set_creation_timestamp(EnvTime::NowMicros());
  metadata.set_graph_hash(absl::StrFormat("%d", dataset()->hash_));
  metadata.set_run_id(absl::StrFormat("%d", run_id_));
  metadata.set_version(kSnapshotFileFormatVersion);
  for (const auto& output_dtype : dataset()->output_dtypes()) {
    metadata.add_dtype(output_dtype);
  }
  metadata.set_finalized(finalized);
  tstring hash_directory = HashDirectory(dataset()->path_, dataset()->hash_);

  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(hash_directory));
  return snapshot_util::WriteMetadataFile(hash_directory, &metadata);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_));

  return dataset()->input_->MakeIterator(
      ctx, this, strings::StrCat(prefix(), "::WriterIterator"), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
    std::vector<Tensor>* tensors, int64* shard_index) {
  std::vector<Tensor> output_tensors;

  // Run the shard function
  TF_RETURN_IF_ERROR(
      instantiated_shard_func_->RunInstantiated(*tensors, &output_tensors));

  if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
      output_tensors[0].NumElements() != 1) {
    return errors::InvalidArgument("`shard_func` must return a scalar int64.");
  }

  // Create writable files if we see an index bigger than our current files.
  *shard_index = output_tensors[0].flat<int64>()(0);
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  *end_of_sequence = false;
  WriterThread* current_writer_thread;

  {
    std::vector<Tensor> output_tensors;
    mutex_lock l(mu_);

    // We initialize late here because restoring from checkpoint comes after
    // the Initialize call. We cannot initialize within Initialize() because
    // we cannot determine whether we should overwrite an existing metadata
    // file or not before `RestoreInternal` is potentially called.
    if (run_dir_.empty()) {
      run_id_ = random::New64();

      // Creates the run directory.
      run_dir_ = RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_),
                              run_id_);
      TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
      TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
    }

    // Writers have either encountered an error or are closed.
    if (!writer_status_.ok() || writers_closed_) {
      *end_of_sequence = true;
      return writer_status_;
    }

    TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence));

    // Finalize metadata file when we are at the end of the iterator.
    if (*end_of_sequence) {
      StopWriterThreads(/*mark_closed=*/true);
      TF_RETURN_IF_ERROR(writer_status_);
      return WriteMetadataFile(ctx->env(), /*finalized=*/true);
    }

    int64 shard_index = 0;
    TF_RETURN_IF_ERROR(GetShardIndex(out_tensors, &shard_index));

    // If the index does not exist, we will start a new thread.
    if (writer_threads_.count(shard_index) == 0) {
      const tstring snapshot_shard_directory =
          SnapshotShardDirectory(run_dir_, shard_index);
      auto thread_data = std::make_unique<WriterThread>(
          this, ctx->env(), shard_index, snapshot_shard_directory,
          current_checkpoint_id_);
      writer_threads_.insert({shard_index, std::move(thread_data)});
    }
    current_writer_thread = writer_threads_[shard_index].get();
  }

  current_writer_thread->EnqueueTensors(*out_tensors);
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kRunId), static_cast<int64>(run_id_)));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCurrentCheckpointId),
                          static_cast<int64>(current_checkpoint_id_)));

  StopWriterThreads(/*mark_closed=*/false);
  writer_threads_.clear();
  current_checkpoint_id_++;

  return SaveInput(ctx, writer, input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  mutex_lock l(mu_);
  int64 run_id_signed;
  int64 current_checkpoint_id;

  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_signed));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointId),
                                        &current_checkpoint_id));

  run_id_ = static_cast<uint64>(run_id_signed);
  run_dir_ =
      RunDirectory(HashDirectory(dataset()->path_, dataset()->hash_), run_id_);
  current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);

  return RestoreInput(ctx, reader, input_impl_);
}

SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Passthrough(
    const Params& params)
    : DatasetIterator<Dataset>(params) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Initialize(
    IteratorContext* ctx) {
  return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return SaveInput(ctx, writer, input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return RestoreInput(ctx, reader, input_impl_);
}

SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  FunctionMetadata::Params reader_params;
  FunctionMetadata::Params shard_params;

  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));

  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
                                               &reader_func_metadata_));
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
                                               &shard_func_metadata_));
}

void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                      DatasetBase** output) {
  tstring path;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));

  // Computes the hash of the preceding items in the graph.
  uint64 graph_hash;
  GraphDef graph_def;
  SerializationContext::Params params;
  std::vector<std::pair<string, Tensor>> input_list;
  params.input_list = &input_list;
  params.external_state_policy =
      SerializationContext::ExternalStatePolicy::kIgnore;
  OP_REQUIRES_OK(
      ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));
  OP_REQUIRES_OK(ctx, HashGraph(graph_def, &graph_hash));

  std::unique_ptr<CapturedFunction> reader_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, reader_func_metadata_,
                                          kReaderFuncOtherArgs, &reader_func));
  std::unique_ptr<CapturedFunction> shard_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, shard_func_metadata_,
                                          kShardFuncOtherArgs, &shard_func));

  *output = new SnapshotDatasetV2Op::Dataset(
      ctx, input, graph_hash, path, compression_, std::move(reader_func),
      std::move(shard_func));
}

namespace {
REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetV2").Device(DEVICE_CPU),
                        SnapshotDatasetV2Op);
}  // namespace

// ==== Legacy Snapshot Implementation (Deprecated) ====

namespace {

// Defaults to 10 GiB per shard.
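As a reading aid for the Writer and Reader paths above: shard_func maps each element to a scalar int64 shard index (GetShardIndex rejects anything else), which selects the NNNNNNNN.shard/ directory the element is written to, while reader_func receives a dataset of per-shard datasets and must return exactly one dataset (Reader::Initialize checks for a single output). A hedged Python sketch of custom functions that satisfy those contracts; the hash-bucket and interleave choices are illustrative assumptions, not taken from this change:

import tensorflow as tf

def custom_shard_func(element):
  # Must return a scalar int64: spread elements across four shards.
  return tf.strings.to_hash_bucket_fast(tf.strings.as_string(element), 4)

def custom_reader_func(datasets):
  # `datasets` is a dataset of per-shard datasets; merge them back into one.
  return datasets.interleave(
      lambda ds: ds, num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset = tf.data.Dataset.range(1000)
dataset = dataset.apply(
    tf.data.experimental.snapshot(
        "/tmp/snapshot",
        shard_func=custom_shard_func,
        reader_func=custom_reader_func))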
@@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace data {
namespace experimental {

const int64 kSnapshotFileFormatVersion = 1;

class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
 public:
  static constexpr const char* const kDatasetType = "Snapshot";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
  static constexpr const char* const kCompression = "compression";
  static constexpr const char* const kReaderFunc = "reader_func";
  static constexpr const char* const kShardFunc = "shard_func";
  static constexpr const char* const kReaderFuncOtherArgs =
      "reader_func_other_args";
  static constexpr const char* const kShardFuncOtherArgs =
      "shard_func_other_args";
  static constexpr const char* const kReaderFuncTarguments =
      "Treader_func_args";
  static constexpr const char* const kShardFuncTarguments = "Tshard_func_args";

  explicit SnapshotDatasetV2Op(OpKernelConstruction* ctx);

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override;

 private:
  class Dataset;

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;

  std::string compression_;

  std::shared_ptr<FunctionMetadata> reader_func_metadata_;
  std::shared_ptr<FunctionMetadata> shard_func_metadata_;
};

}  // namespace experimental
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
@ -899,6 +899,26 @@ REGISTER_OP("SnapshotDataset")
|
|||||||
return shape_inference::ScalarShape(c);
|
return shape_inference::ScalarShape(c);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SnapshotDatasetV2")
|
||||||
|
.Input("input_dataset: variant")
|
||||||
|
.Input("path: string")
|
||||||
|
.Input("reader_func_other_args: Treader_func_args")
|
||||||
|
.Input("shard_func_other_args: Tshard_func_args")
|
||||||
|
.Output("handle: variant")
|
||||||
|
.Attr("output_types: list(type) >= 1")
|
||||||
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
|
.Attr("compression: string = ''")
|
||||||
|
.Attr("reader_func: func")
|
||||||
|
.Attr("shard_func: func")
|
||||||
|
.Attr("Treader_func_args: list(type) >= 0")
|
||||||
|
.Attr("Tshard_func_args: list(type) >= 0")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
shape_inference::ShapeHandle unused;
|
||||||
|
// `path` should be a scalar.
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||||
|
return shape_inference::ScalarShape(c);
|
||||||
|
});
|
||||||
|
|
||||||
REGISTER_OP("SqlDataset")
|
REGISTER_OP("SqlDataset")
|
||||||
.Input("driver_name: string")
|
.Input("driver_name: string")
|
||||||
.Input("data_source_name: string")
|
.Input("data_source_name: string")
|
||||||
|
@@ -75,6 +75,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@sample_from_datasets
@@scan
@@shuffle_and_repeat
@@snapshot
@@take_while
@@to_variant
@@unbatch
@@ -128,6 +129,7 @@ from tensorflow.python.data.experimental.ops.readers import SqlDataset
from tensorflow.python.data.experimental.ops.resampling import rejection_resample
from tensorflow.python.data.experimental.ops.scan_ops import scan
from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.python.data.experimental.ops.snapshot import snapshot
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import bytes_produced_stats
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
@@ -753,9 +753,12 @@ tf_py_test(
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:string_ops",
        "//tensorflow/python/data/experimental/ops:readers",
        "//tensorflow/python/data/experimental/ops:snapshot",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:readers",
        "//tensorflow/python/data/util:nest",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -199,7 +199,9 @@ tf_py_test(
    deps = [
        ":dataset_serialization_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -34,6 +34,97 @@ class SnapshotDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):

  def _build_snapshot_dataset(self, repeat=False):

    def ds_fn():
      self._snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
      if not os.path.exists(self._snapshot_dir):
        os.mkdir(self._snapshot_dir)

      dataset = dataset_ops.Dataset.range(100)
      dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
      if repeat:
        dataset = dataset.repeat(2)
      return dataset

    return ds_fn

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointBeforeEpochEndNoRepeat(self):
    ds_fn = self._build_snapshot_dataset(repeat=False)
    outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(50))
    outputs.extend(
        self.gen_outputs(ds_fn, [], 50, ckpt_saved=True, verify_exhausted=True))
    self.assertSequenceEqual(outputs, range(100))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointBeforeOneEpochWithReading(self):
    ds_fn = self._build_snapshot_dataset(repeat=True)

    # Generate 50 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 50, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(50)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    t = self.gen_outputs(ds_fn, [], 150, ckpt_saved=True, verify_exhausted=True)
    outputs.extend(t)
    self.assertSequenceEqual(
        outputs,
        list(range(50)) + list(range(50, 100)) + list(range(100)))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointBeforeOneEpochThenRunAFewSteps(self):
    ds_fn = self._build_snapshot_dataset(repeat=False)
    outputs = self.gen_outputs(
        ds_fn, [10], 20, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, range(20))

    outputs = outputs[:10]
    outputs.extend(
        self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True))
    self.assertSequenceEqual(outputs, range(100))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointAfterOneEpoch(self):
    ds_fn = self._build_snapshot_dataset(repeat=True)

    # Generate 110 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 110, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(100)) + list(range(10)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
    outputs.extend(t)
    self.assertSequenceEqual(
        outputs,
        list(range(100)) + list(range(10)) + list(range(10, 100)))

  @combinations.generate(test_base.default_test_combinations())
  def testCheckpointAfterOneEpochRunFewSteps(self):
    ds_fn = self._build_snapshot_dataset(repeat=True)

    # Generate 120 entries from iterator and save checkpoint at 110.
    outputs = self.gen_outputs(
        ds_fn, [110], 120, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(100)) + list(range(20)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs = outputs[:110]
    t = self.gen_outputs(ds_fn, [], 90, ckpt_saved=True, verify_exhausted=True)
    outputs.extend(t)
    self.assertSequenceEqual(
        outputs,
        list(range(100)) + list(range(10)) + list(range(10, 100)))


class LegacySnapshotDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):

  def _build_snapshot_dataset(self,
                              num_threads=1,
                              repeat=False,
@@ -17,11 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os
import shutil
import time

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import snapshot
@@ -40,6 +42,285 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,

  def setUp(self):
    super(SnapshotDatasetTest, self).setUp()
    tmpdir = self.get_temp_dir()
    tmpdir = os.path.join(tmpdir, "snapshot")
    os.mkdir(tmpdir)
    self._snapshot_dir = tmpdir

  def tearDown(self):
    super(SnapshotDatasetTest, self).tearDown()
    shutil.rmtree(self._snapshot_dir)

  def createTFRecords(self, num_files=10, num_records=100):
    self._num_files = num_files
    self._num_records = num_records
    self._test_filenames = self._createFiles()

  def removeTFRecords(self):
    for filename in self._test_filenames:
      os.remove(filename)
    self._test_filenames = []
    self._num_files = None
    self._num_records = None

  def assertDatasetProducesSet(self, dataset, expected):
    actual = []
    next_fn = self.getNext(dataset)
    for _ in range(len(expected)):
      elem = self.evaluate(next_fn())
      actual.append(elem)
    self.assertCountEqual(actual, expected)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_fn())

  def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
                                      num_runs_per_fingerprint,
                                      num_snapshot_shards_per_run):
    dirlist_raw = os.listdir(directory)
    dirlist = []

    # Ignore the graphdef pbtxts we write for debugging purposes.
    for i in range(len(dirlist_raw)):
      if not dirlist_raw[i].endswith("-graph.pbtxt"):
        dirlist.append(dirlist_raw[i])

    self.assertLen(dirlist, num_fingerprints)

    for i in range(num_fingerprints):
      fingerprint_dir = os.path.join(directory, dirlist[i])
      fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
      self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1)
      self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint],
                       "snapshot.metadata")

      for j in range(num_runs_per_fingerprint):
        run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
        run_dirlist = sorted(os.listdir(run_dir))
        self.assertLen(run_dirlist, num_snapshot_shards_per_run)

        file_counter = 0
        for filename in run_dirlist:
          self.assertEqual(filename, "%08d.shard" % file_counter)
          file_counter += 1
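
  # For reference, a sketch of the on-disk layout that the helper above
  # checks (directory names are illustrative; only the counts and the fixed
  # names "snapshot.metadata" and "%08d.shard" are asserted):
  #
  #   <snapshot_dir>/
  #     <fingerprint>/
  #       <run_id>/
  #         00000000.shard
  #         00000001.shard
  #         ...
  #       snapshot.metadata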

  @combinations.generate(test_base.default_test_combinations())
  def testCreateSnapshotDataset(self):
    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
    dataset.apply(snapshot.snapshot(self._snapshot_dir))

  @combinations.generate(test_base.default_test_combinations())
  def testReadSnapshotDatasetDefault(self):
    self.createTFRecords()
    filenames = self._test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 100)
    ]

    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset, expected)
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

    self.removeTFRecords()
    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset2, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testReadSnapshotDatasetCustomShardFn(self):
    self.createTFRecords()
    filenames = self._test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 100)
    ]

    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: np.int64(0)))
    self.assertDatasetProduces(dataset, expected)
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=1)

    self.removeTFRecords()
    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0))
    self.assertDatasetProduces(dataset2, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testReadSnapshotDatasetCustomReaderFn(self):
    self.createTFRecords()
    filenames = self._test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 100)
    ]

    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(
            self._snapshot_dir,
            reader_func=(
                lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                    lambda x: x,
                    cycle_length=4,
                    num_parallel_calls=4))))
    self.assertDatasetProduces(dataset, expected)
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

    self.removeTFRecords()
    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(
            self._snapshot_dir,
            reader_func=(
                lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                    lambda x: x,
                    cycle_length=4,
                    num_parallel_calls=4))))
    self.assertDatasetProducesSet(dataset2, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testSnapshotDatasetInvalidShardFn(self):
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaises(TypeError):
      dataset = dataset.apply(
          snapshot.snapshot(
              self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
      next_fn = self.getNext(dataset)
      self.evaluate(next_fn())

  @combinations.generate(test_base.default_test_combinations())
  def testSnapshotDatasetInvalidReaderFn(self):
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaises(TypeError):
      dataset = dataset.apply(
          snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
      next_fn = self.getNext(dataset)
      self.evaluate(next_fn())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetSimple(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset, list(range(1000)))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetMultipleFingerprints(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset1, list(range(1000)))

    dataset2 = dataset_ops.Dataset.range(2000)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset2, list(range(2000)))

    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=2,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset1, list(range(1000)))
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset2, list(range(1000)))

    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
    next1 = self.getNext(dataset1)
    for i in range(500):
      self.assertEqual(i, self.evaluate(next1()))

    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
    next2 = self.getNext(dataset2)
    for i in range(500):
      self.assertEqual(i, self.evaluate(next2()))

    for i in range(500, 1000):
      self.assertEqual(i, self.evaluate(next1()))
      self.assertEqual(i, self.evaluate(next2()))

    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=2,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotCustomShardFunction(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.enumerate()
    dataset = dataset.apply(
        snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2))
    dataset = dataset.map(lambda _, elem: elem)
    self.assertDatasetProduces(dataset, list(range(1000)))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=2)

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotDatasetWithTuples(self):
    dataset1 = dataset_ops.Dataset.range(0, 1000)
    dataset2 = dataset_ops.Dataset.range(1000, 2000)
    dataset3 = dataset_ops.Dataset.range(2000, 3000)
    dataset4 = dataset_ops.Dataset.range(3000, 4000)

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4))
    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
    next1 = self.getNext(dataset)
    for i in range(0, 1000):
      self.assertEqual((i, i + 1000, i + 2000, i + 3000),
                       self.evaluate(next1()))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())


class LegacySnapshotDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):

  def setUp(self):
    super(LegacySnapshotDatasetTest, self).setUp()
    self.removeTFRecords()
    tmpdir = self.get_temp_dir()
    tmpdir = os.path.join(tmpdir, "snapshot")
@@ -47,7 +328,7 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    self.snapshot_dir = tmpdir

  def tearDown(self):
-    super(SnapshotDatasetTest, self).tearDown()
+    super(LegacySnapshotDatasetTest, self).tearDown()
    shutil.rmtree(self.snapshot_dir)

  def removeTFRecords(self):
@@ -63,8 +344,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
  def makeSnapshotDirectory(self):
    return self.snapshot_dir

-  def assertSnapshotDirectoryContains(
-      self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files):
+  def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
+                                      num_runs_per_fp, num_snapshot_files):
    dirlist_raw = os.listdir(directory)
    dirlist = []

@@ -465,8 +746,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
              ]),
              combinations.combine(threads=2, size=[1, 2]) +
              combinations.combine(threads=8, size=[1, 4, 8]))))
-  def testReadSnapshotBackAfterMultiThreadedWrite(
-      self, compression, threads, size):
+  def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
+                                                  size):
    self.setUpTFRecord()
    filenames = self.test_filenames

@@ -17,12 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
@@ -99,6 +103,8 @@ class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
@@ -186,3 +192,165 @@ def legacy_snapshot(path,
        snapshot_name=snapshot_name)

  return _apply_fn


class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset that allows saving and re-use of already processed data."""

  def __init__(self,
               input_dataset,
               path,
               shard_func,
               compression=None,
               reader_func=None,
               pending_snapshot_expiry_seconds=None,
               use_legacy_function=False):

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._input_dataset = input_dataset
    self._path = path
    self._compression = compression

    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        self._transformation_name() + ".reader_func",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(input_dataset.element_spec)),
        use_legacy_function=use_legacy_function)
    self._shard_func = dataset_ops.StructuredFunctionWrapper(
        shard_func,
        self._transformation_name() + ".shard_func",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)

    if ((not self._shard_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int32))) and
        (not self._shard_func.output_structure.is_compatible_with(
            tensor_spec.TensorSpec([], dtypes.int64)))):
      raise TypeError(
          "shard_func must return a 0-dimension tensor containing an int.")

    variant_tensor = ged_ops.snapshot_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        path,
        self._reader_func.function.captured_inputs,
        self._shard_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        shard_func=self._shard_func.function,
        **self._flat_structure)
    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._reader_func, self._shard_func]

  def _transformation_name(self):
    return "Dataset.snapshot"


@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading off disk storage and network
  bandwidth for freeing up more valuable CPU resources and accelerator compute
  time.

  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read from and written to by passing in
  user-defined functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user-specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func` could
  be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot(
      shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user-specified function that accepts a single argument:
  (1) a Dataset of Datasets, each representing a "split" of elements of the
  original dataset. The cardinality of the input dataset matches the
  number of the shards specified in the `shard_func` (see above). The function
  should return a Dataset of elements of the original dataset.

  Users may want to specify this function to control how snapshot files should
  be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define. This
  function enables both dataset shuffling and parallel reading of datasets:

  ```python
  def user_reader_func(datasets):
    # shuffle the datasets splits
    datasets = datasets.shuffle(NUM_CORES)
    # read datasets in parallel and interleave their elements
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot(
      reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

  Args:
    path: Required. A directory to use for storing / loading the snapshot to /
      from.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
      Defaults to AUTO, which attempts to pick an appropriate compression
      algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from snapshot
      shards.
    shard_func: Optional. A function to control how to shard data when writing
      a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Actual dataset transformation."""
    if shard_func is None:
      dataset = dataset.enumerate()
      dataset = _SnapshotDataset(
          input_dataset=dataset,
          path=path,
          compression=compression,
          reader_func=reader_func,
          # This will not do the right thing where the graph is built on a
          # different machine than the executor (e.g. Cloud TPUs).
          shard_func=lambda index, _: index % multiprocessing.cpu_count())
      return dataset.map(lambda _, elem: elem)
    else:
      return _SnapshotDataset(
          input_dataset=dataset,
          path=path,
          compression=compression,
          reader_func=reader_func,
          shard_func=shard_func)

  return _apply_fn
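Taken together, a minimal end-to-end sketch of the new API (the path and the `map` stage are illustrative, and TF 2.x eager execution is assumed): the first run of the pipeline below writes the snapshot, and re-running the identical pipeline afterwards reads back from the snapshot instead of recomputing the preprocessing.

```python
import tensorflow as tf

def build_dataset():
  dataset = tf.data.Dataset.range(1000)
  dataset = dataset.map(lambda x: x * 2)  # stand-in for expensive preprocessing
  return dataset.apply(tf.data.experimental.snapshot("/tmp/snapshot"))

# First run: executes map() and writes snapshot shards under /tmp/snapshot.
for _ in build_dataset():
  pass

# Later runs of the same pipeline: elements are served from the snapshot.
for _ in build_dataset():
  pass
```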
@@ -220,6 +220,10 @@ tf_module {
    name: "shuffle_and_repeat"
    argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "snapshot"
    argspec: "args=[\'path\', \'compression\', \'reader_func\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\'], "
  }
  member_method {
    name: "take_while"
    argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
@@ -3992,6 +3992,10 @@ tf_module {
    name: "SnapshotDataset"
    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
  }
  member_method {
    name: "SnapshotDatasetV2"
    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
  }
  member_method {
    name: "SobolSample"
    argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
@@ -188,6 +188,10 @@ tf_module {
    name: "shuffle_and_repeat"
    argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "snapshot"
    argspec: "args=[\'path\', \'compression\', \'reader_func\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\'], "
  }
  member_method {
    name: "take_while"
    argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
@@ -3992,6 +3992,10 @@ tf_module {
    name: "SnapshotDataset"
    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
  }
  member_method {
    name: "SnapshotDatasetV2"
    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
  }
  member_method {
    name: "SobolSample"
    argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "