[tf.data] Implementation of tf.data.experimental.save and tf.data.experimental.load. The former makes it possible to save the output of a dataset, while the latter makes it possible to load previously saved data.

Fixes: 38483
PiperOrigin-RevId: 315991164
Change-Id: I30da604fdd489902ff4771b685e413447d3e9e9d
Jiri Simsa 2020-06-11 15:32:34 -07:00 committed by TensorFlower Gardener
parent 348b3e6224
commit 4d58a67a9f
18 changed files with 885 additions and 11 deletions
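
As context for the diffs below, here is a minimal round-trip sketch of the new API, adapted from the docstring examples added in this change (the temp-dir path is purely illustrative):

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.gettempdir(), "saved_data")  # illustrative path

# Save the contents of a dataset to disk...
dataset = tf.data.Dataset.range(2)
tf.data.experimental.save(dataset, path)

# ...and load it back later. element_spec must be supplied so that shape
# inference of the loaded dataset does not require any I/O.
loaded = tf.data.experimental.load(
    path, element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))
for elem in loaded:
  print(elem)  # tf.Tensor(0, ...), tf.Tensor(1, ...)
```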

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "LoadDataset"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SaveDataset"
visibility: HIDDEN
}

View File

@ -252,6 +252,23 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "io_ops",
srcs = ["io_ops.cc"],
hdrs = ["io_ops.h"],
deps = [
":snapshot_util",
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:op_requires",
"//tensorflow/core/kernels/data:captured_function",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:name_utils",
],
)
tf_kernel_library(
name = "lmdb_dataset_op",
srcs = ["lmdb_dataset_op.cc"],
@ -538,7 +555,6 @@ cc_library(
"//tensorflow/core/platform:random",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
],
)
@ -576,7 +592,6 @@ tf_kernel_library(
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
],
)
@ -724,6 +739,7 @@ tf_kernel_library(
":group_by_reducer_dataset_op",
":group_by_window_dataset_op",
":ignore_errors_dataset_op",
":io_ops",
":lmdb_dataset_op",
":map_and_batch_dataset_op",
":matching_files_dataset_op",

View File

@ -0,0 +1,377 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/io_ops.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"
namespace tensorflow {
namespace data {
namespace experimental {
SaveDatasetOp::SaveDatasetOp(OpKernelConstruction* ctx)
: HybridAsyncOpKernel(ctx, "tf_data_save_dataset") {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, /*params=*/{},
&func_metadata_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseShardFunc, &use_shard_func_));
}
Status SaveDatasetOp::DoCompute(OpKernelContext* ctx) {
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
tstring path;
TF_RETURN_IF_ERROR(ParseScalarArgument(ctx, kPath, &path));
// Create a run directory.
auto run_id = random::New64();
auto run_dir = snapshot_util::RunDirectory(path, run_id);
TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir));
TF_RETURN_IF_ERROR(
WriteMetadataFile(ctx->env(), path, run_id, dataset->output_dtypes(),
/*num_elements=*/0, /*finalized=*/false));
std::unique_ptr<CapturedFunction> captured_func;
TF_RETURN_IF_ERROR(CapturedFunction::Create(
ctx, func_metadata_, kShardFuncOtherArgs, &captured_func));
uint64 num_elements = 0;
TF_RETURN_IF_ERROR(WriteData(ctx, dataset, std::move(captured_func), run_dir,
&num_elements));
TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), path, run_id,
dataset->output_dtypes(), num_elements,
/*finalized=*/true));
return Status::OK();
}
Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset,
std::unique_ptr<CapturedFunction> captured_func,
const std::string& run_dir,
uint64* num_elements) {
IteratorContext::Params params(ctx);
auto function_handle_cache =
absl::make_unique<FunctionHandleCache>(params.flr);
params.function_handle_cache = function_handle_cache.get();
ResourceMgr resource_mgr;
params.resource_mgr = &resource_mgr;
CancellationManager cancellation_manager(ctx->cancellation_manager());
params.cancellation_manager = &cancellation_manager;
IteratorContext iter_ctx(std::move(params));
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
TF_RETURN_IF_ERROR(
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(
dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr, "Save", &iterator));
mutex mu;
Status status;
absl::flat_hash_map<int64, std::unique_ptr<snapshot_util::AsyncWriter>>
writers;
int64 shard_index = -1;
while (true) {
if (ctx->cancellation_manager()->IsCancelled()) {
return errors::Cancelled("Operation was cancelled");
}
std::vector<Tensor> element;
bool end_of_input;
TF_RETURN_IF_ERROR(iterator->GetNext(&iter_ctx, &element, &end_of_input));
if (end_of_input) {
break;
}
(*num_elements)++;
// Run the shard function to compute the shard index.
TF_RETURN_IF_ERROR(GetShardIndex(
&iter_ctx, instantiated_captured_func.get(), element, &shard_index));
// If the index does not exist, we will start a new thread.
if (writers.count(shard_index) == 0) {
const auto snapshot_shard_directory =
snapshot_util::ShardDirectory(run_dir, shard_index);
auto writer_thread = std::make_unique<snapshot_util::AsyncWriter>(
ctx->env(), shard_index, snapshot_shard_directory,
/*checkpoint_id=*/0, compression_, kFileFormatVersion,
dataset->output_dtypes(), [&mu, &status](Status s) {
mutex_lock l(mu);
status.Update(s);
});
writers.insert({shard_index, std::move(writer_thread)});
}
writers[shard_index]->Write(element);
}
// Push the end of sequence signal to each of the threads to close files.
for (auto& writer : writers) {
writer.second->SignalEOF();
}
// Wait for the writer threads to join.
writers.clear();
return status;
}
Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
InstantiatedCapturedFunction* function,
const std::vector<Tensor>& element,
int64* shard_index) {
if (!use_shard_func_) {
*shard_index = (*shard_index + 1) % port::NumSchedulableCPUs();
return Status::OK();
}
std::vector<Tensor> output_tensors;
TF_RETURN_IF_ERROR(
function->RunWithBorrowedArgs(ctx, element, &output_tensors));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {
return errors::InvalidArgument("`shard_func` must return a scalar int64.");
}
*shard_index = output_tensors[0].flat<int64>()(0);
return Status::OK();
}
Status SaveDatasetOp::WriteMetadataFile(Env* env, const std::string& path,
uint64 run_id,
const DataTypeVector& output_dtypes,
uint64 num_elements, bool finalized) {
SnapshotMetadataRecord metadata;
metadata.set_creation_timestamp(EnvTime::NowMicros());
metadata.set_run_id(strings::Printf("%llu", run_id));
metadata.set_version(kFileFormatVersion);
for (const auto& output_dtype : output_dtypes) {
metadata.add_dtype(output_dtype);
}
metadata.set_finalized(finalized);
metadata.set_num_elements(num_elements);
return snapshot_util::WriteMetadataFile(env, path, &metadata);
}
class LoadDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const tstring& path,
SnapshotMetadataRecord metadata, const std::string& compression,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
captured_func_(std::move(captured_func)),
compression_(compression),
metadata_(std::move(metadata)),
output_types_(output_types),
output_shapes_(output_shapes),
path_(path) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
}
int64 Cardinality() const override { return metadata_.num_elements(); }
Status CheckExternalState() const override {
return captured_func_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* path_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(path_, &path_node));
std::vector<Node*> reader_func_other_args;
DataTypeVector reader_func_other_args_types;
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(
ctx, b, &reader_func_other_args, &reader_func_other_args_types));
// Attr: compression
AttrValue compression_attr;
b->BuildAttrValue(compression_, &compression_attr);
// Attr: reader_func
AttrValue reader_func_attr;
b->BuildAttrValue(captured_func_->func(), &reader_func_attr);
AttrValue reader_func_arguments_types_attr;
b->BuildAttrValue(reader_func_other_args_types,
&reader_func_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {std::make_pair(0, path_node)}, // Single tensor inputs.
{std::make_pair(1, reader_func_other_args)}, // Tensor list inputs.
{std::make_pair(kCompression, compression_attr),
std::make_pair(kReaderFunc, reader_func_attr),
std::make_pair(kReaderFuncTarguments,
reader_func_arguments_types_attr)}, // Attrs
output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
~Iterator() override { input_->Unref(); }
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_));
TF_RETURN_IF_ERROR(InitializeInput(ctx));
return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeUnknownRatioNode(std::move(args));
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
return errors::Unimplemented("Checkpointing is currently not supported.");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented("Checkpointing is currently not supported.");
}
private:
Status InitializeInput(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
auto run_dir = snapshot_util::RunDirectory(dataset()->path_,
dataset()->metadata_.run_id());
std::vector<std::string> snapshot_shard_dirs;
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
io::JoinPath(run_dir,
strings::Printf("%s%s", "*",
snapshot_util::kShardDirectorySuffix)),
&snapshot_shard_dirs));
std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());
DatasetBase* dataset_of_snapshot_files;
TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
ctx->env(), snapshot_shard_dirs, dataset()->compression_,
dataset()->metadata_.version(), dataset()->output_dtypes(),
dataset()->output_shapes(), /*start_index=*/0,
&dataset_of_snapshot_files));
Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
&input_dataset_tensor));
std::vector<Tensor> reader_input;
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));
TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
ctx, std::move(reader_input), &reader_output));
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");
}
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(reader_output[0], &input_));
// We need to take a reference here as we will use the input_ and
// its iterator.
input_->Ref();
return Status::OK();
}
mutex mu_;
DatasetBase* input_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
};
const std::unique_ptr<CapturedFunction> captured_func_;
const std::string compression_;
const SnapshotMetadataRecord metadata_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
const tstring path_;
};
LoadDatasetOp::LoadDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, /*params=*/{},
&func_metadata_));
}
void LoadDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
tstring path;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPath, &path));
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
ctx, CapturedFunction::Create(ctx, func_metadata_, kReaderFuncOtherArgs,
&captured_func));
bool metadata_file_exists;
experimental::SnapshotMetadataRecord metadata;
OP_REQUIRES_OK(ctx, snapshot_util::ReadMetadataFile(
ctx->env(), path, &metadata, &metadata_file_exists));
OP_REQUIRES(ctx, metadata_file_exists,
errors::NotFound("Could not find metadata file."));
*output =
new Dataset(ctx, path, std::move(metadata), compression_,
std::move(captured_func), output_types_, output_shapes_);
}
namespace {
REGISTER_KERNEL_BUILDER(Name("SaveDataset").Device(DEVICE_CPU), SaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LoadDataset").Device(DEVICE_CPU), LoadDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow
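
The shard-assignment logic in `SaveDatasetOp::WriteData` and `GetShardIndex` above can be summarized with a plain-Python model (an illustration, not TensorFlow code; `assign_shards` and `num_shards` are hypothetical names): when no `shard_func` is supplied, elements are assigned round-robin across as many shards as there are schedulable CPUs; otherwise the user function's scalar int64 result selects the shard.

```python
import os

def assign_shards(elements, shard_func=None, num_shards=None):
  """Yields (shard_index, element) pairs, mimicking the kernel's selection."""
  num_shards = num_shards or os.cpu_count()  # stand-in for NumSchedulableCPUs()
  shard_index = -1
  for element in elements:
    if shard_func is None:
      # Default: round-robin across the available shards.
      shard_index = (shard_index + 1) % num_shards
    else:
      # A custom shard_func must produce a scalar int64 shard id.
      shard_index = int(shard_func(element))
    yield shard_index, element

# Even/odd split into two shards, similar to the shard_func cases in io_test.py.
print(list(assign_shards(range(6), shard_func=lambda x: x % 2)))
# [(0, 0), (1, 1), (0, 2), (1, 3), (0, 4), (1, 5)]
```

`WriteData` keeps one `snapshot_util::AsyncWriter` per distinct shard index, so the shard mapping also determines how many writer threads are spawned and how many shard directories appear under the run directory.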

View File

@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_IO_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_IO_OPS_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace data {
namespace experimental {
// An operation that can save a dataset to one or more files.
class SaveDatasetOp : public HybridAsyncOpKernel {
public:
static constexpr const char* const kCompression = "compression";
static constexpr const char* const kPath = "path";
static constexpr const char* const kShardFunc = "shard_func";
static constexpr const char* const kShardFuncOtherArgs =
"shard_func_other_args";
static constexpr const char* const kUseShardFunc = "use_shard_func";
explicit SaveDatasetOp(OpKernelConstruction* ctx);
Status DoCompute(OpKernelContext* ctx) override;
private:
static constexpr const int kFileFormatVersion = 2;
Status ConsumeElement();
Status GetShardIndex(IteratorContext* ctx,
InstantiatedCapturedFunction* function,
const std::vector<Tensor>& element, int64* shard_index);
Status WriteData(OpKernelContext* ctx, DatasetBase* dataset,
std::unique_ptr<CapturedFunction> captured_func,
const std::string& run_dir, uint64* num_elements);
Status WriteMetadataFile(Env* env, const std::string& path, uint64 run_id,
const DataTypeVector& output_dtypes,
uint64 num_elements, bool finalized);
bool use_shard_func_;
std::string compression_;
std::shared_ptr<FunctionMetadata> func_metadata_;
};
// An operation that can load a dataset from one or more files.
class LoadDatasetOp : public DatasetOpKernel {
public:
static constexpr const char* const kCompression = "compression";
static constexpr const char* const kDatasetType = "Load";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kPath = "path";
static constexpr const char* const kReaderFunc = "reader_func";
static constexpr const char* const kReaderFuncOtherArgs =
"reader_func_other_args";
static constexpr const char* const kReaderFuncTarguments =
"Treader_func_args";
explicit LoadDatasetOp(OpKernelConstruction* ctx);
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
private:
class Dataset;
std::string compression_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
std::shared_ptr<FunctionMetadata> func_metadata_;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_IO_OPS_H_

View File

@ -514,7 +514,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
io::JoinPath(
run_dir,
absl::StrFormat("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
strings::Printf("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
&snapshot_shard_dirs));
std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());
@ -596,8 +596,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
experimental::SnapshotMetadataRecord metadata;
metadata.set_creation_timestamp(EnvTime::NowMicros());
metadata.set_graph_hash(absl::StrFormat("%d", dataset()->hash_));
metadata.set_run_id(absl::StrFormat("%d", run_id_));
metadata.set_graph_hash(strings::Printf("%llu", dataset()->hash_));
metadata.set_run_id(strings::Printf("%llu", run_id_));
metadata.set_version(kFileFormatVersion);
for (const auto& output_dtype : dataset()->output_dtypes()) {
metadata.add_dtype(output_dtype);

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"

View File

@ -18,7 +18,6 @@ limitations under the License.
#include <queue>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -37,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"
@ -50,11 +50,11 @@ namespace snapshot_util {
CustomReader::kSnappyReaderOutputBufferSizeBytes;
std::string HashDirectory(const std::string& path, uint64 hash) {
return io::JoinPath(path, absl::StrFormat("%d", hash));
return io::JoinPath(path, strings::Printf("%llu", hash));
}
std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
return RunDirectory(hash_directory, absl::StrFormat("%d", run_id));
return RunDirectory(hash_directory, strings::Printf("%llu", run_id));
}
std::string RunDirectory(const std::string& hash_directory,
@ -63,13 +63,13 @@ std::string RunDirectory(const std::string& hash_directory,
}
std::string ShardDirectory(const std::string& run_directory, int64 shard_id) {
return io::JoinPath(run_directory, absl::StrFormat("%08d%s", shard_id,
return io::JoinPath(run_directory, strings::Printf("%08llu%s", shard_id,
kShardDirectorySuffix));
}
std::string GetCheckpointFileName(const std::string& shard_directory,
uint64 checkpoint_id) {
return io::JoinPath(shard_directory,
absl::StrFormat("%08d.snapshot", checkpoint_id));
strings::Printf("%08llu.snapshot", checkpoint_id));
}
Status Writer::Create(Env* env, const std::string& filename,

View File

@ -914,6 +914,38 @@ REGISTER_OP("SnapshotDatasetV2")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("SaveDataset")
.Input("input_dataset: variant")
.Input("path: string")
.Input("shard_func_other_args: Tshard_func_args")
.Attr("compression: string = ''")
.Attr("shard_func: func")
.Attr("use_shard_func: bool = true")
.Attr("Tshard_func_args: list(type) >= 0")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `path` should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("LoadDataset")
.Input("path: string")
.Input("reader_func_other_args: Treader_func_args")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("compression: string = ''")
.Attr("reader_func: func")
.Attr("Treader_func_args: list(type) >= 0")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `path` should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("SqlDataset")
.Input("driver_name: string")
.Input("data_source_name: string")

View File

@ -26,6 +26,8 @@ message SnapshotMetadataRecord {
int64 version = 4;
// A list of tensor dtype corresponding to each element of the snapshot.
repeated .tensorflow.DataType dtype = 5;
// The number of elements in the snapshot.
int64 num_elements = 6;
bool finalized = 1000;
}
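
The new `num_elements` field is what the `Cardinality()` override in `LoadDatasetOp::Dataset` (io_ops.cc above) returns, so a loaded dataset can report an exact cardinality without scanning its files. A small sketch, assuming the same round trip as in the commit description:

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.gettempdir(), "cardinality_example")  # illustrative
tf.data.experimental.save(tf.data.Dataset.range(42), path)

loaded = tf.data.experimental.load(
    path, element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))
# The value comes from the num_elements metadata field, not from reading data.
print(loaded.cardinality().numpy())  # 42
```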

View File

@ -63,6 +63,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@group_by_window
@@ignore_errors
@@latency_stats
@@load
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
@ -73,6 +74,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@prefetch_to_device
@@rejection_resample
@@sample_from_datasets
@@save
@@scan
@@shuffle_and_repeat
@@snapshot
@ -114,6 +116,8 @@ from tensorflow.python.data.experimental.ops.grouping import Reducer
from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
from tensorflow.python.data.experimental.ops.io import load
from tensorflow.python.data.experimental.ops.io import save
from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.python.data.experimental.ops.optimization_options import MapVectorizationOptions

View File

@ -259,6 +259,18 @@ tf_py_test(
],
)
tf_py_test(
name = "io_test",
srcs = ["io_test.py"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:io",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
name = "make_batched_features_dataset_test",
size = "medium",

View File

@ -0,0 +1,87 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `tf.data.experimental.{save,load}` operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import io
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
class IOTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
super(IOTest, self).setUp()
tmpdir = self.get_temp_dir()
tmpdir = os.path.join(tmpdir, "io_test")
os.mkdir(tmpdir)
self._test_dir = tmpdir
def tearDown(self):
super(IOTest, self).tearDown()
shutil.rmtree(self._test_dir)
@combinations.generate(
combinations.times(test_base.eager_only_combinations(),
combinations.combine(compression=[None, "GZIP"])))
def testBasic(self, compression):
dataset = dataset_ops.Dataset.range(42)
io.save(dataset, self._test_dir, compression=compression)
dataset2 = io.load(
self._test_dir, dataset.element_spec, compression=compression)
self.assertDatasetProduces(dataset2, range(42))
@combinations.generate(test_base.eager_only_combinations())
def testCardinality(self):
dataset = dataset_ops.Dataset.range(42)
io.save(dataset, self._test_dir)
dataset2 = io.load(self._test_dir, dataset.element_spec)
self.assertEqual(self.evaluate(dataset2.cardinality()), 42)
@combinations.generate(test_base.eager_only_combinations())
def testCustomShardFunction(self):
dataset = dataset_ops.Dataset.range(42)
io.save(dataset, self._test_dir, shard_func=lambda x: x // 21)
dataset2 = io.load(self._test_dir, dataset.element_spec)
expected = []
for i in range(21):
expected.extend([i, i + 21])
self.assertDatasetProduces(dataset2, expected)
@combinations.generate(test_base.eager_only_combinations())
def testCustomReaderFunction(self):
dataset = dataset_ops.Dataset.range(42)
io.save(dataset, self._test_dir, shard_func=lambda x: x % 7)
dataset2 = io.load(
self._test_dir,
dataset.element_spec,
reader_func=lambda x: x.flat_map(lambda y: y))
expected = []
for i in range(7):
expected.extend(range(i, 42, 7))
self.assertDatasetProduces(dataset2, expected)
if __name__ == "__main__":
test.main()

View File

@ -164,6 +164,18 @@ py_library(
],
)
py_library(
name = "io",
srcs = [
"io.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_library(
name = "iterator_ops",
srcs = [
@ -493,6 +505,7 @@ py_library(
":get_single_element",
":grouping",
":interleave_ops",
":io",
":map_defun",
":matching_files",
":optimization",

View File

@ -0,0 +1,205 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for save and loading a dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export
COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "NONE"
@tf_export("data.experimental.save", v1=[])
def save(dataset, path, compression=None, shard_func=None):
"""Saves the content of the given dataset.
Example usage:
>>> import tempfile
>>> path = os.path.join(tempfile.gettempdir(), "saved_data")
>>> # Save a dataset
>>> dataset = tf.data.Dataset.range(2)
>>> tf.data.experimental.save(dataset, path)
>>> new_dataset = tf.data.experimental.load(path,
... tf.TensorSpec(shape=(), dtype=tf.int64))
>>> for elem in new_dataset:
... print(elem)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
The dataset is saved in multiple file "shards". By default, the dataset
output is divided into shards in a round-robin fashion, but custom sharding can
be specified via the `shard_func` function. For example, you can save the
dataset using a single shard as follows:
```python
dataset = make_dataset()
def custom_shard_func(element):
  return 0
tf.data.experimental.save(
    path="/path/to/data", ..., shard_func=custom_shard_func)
```
NOTE: The directory layout and file format used for saving the dataset are
considered an implementation detail and may change. For this reason, datasets
saved through `tf.data.experimental.save` should only be consumed through
`tf.data.experimental.load`, which is guaranteed to be backwards compatible.
Args:
dataset: The dataset to save.
path: Required. A directory to use for saving the dataset.
compression: Optional. The algorithm to use to compress data when writing
it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
shard_func: Optional. A function to control the mapping of dataset elements
to file shards. The function is expected to map elements of the input
dataset to int64 shard IDs. If present, the function will be traced and
executed as graph computation.
"""
if shard_func is None:
use_shard_func = False
shard_func = lambda *x: None # a dummy function that will not be used
else:
use_shard_func = True
wrapped_func = dataset_ops.StructuredFunctionWrapper(
shard_func,
"save()",
input_structure=dataset.element_spec,
add_to_graph=False)
path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
shard_func = wrapped_func.function
shard_func.add_to_graph(ops.get_default_graph())
# pylint: disable=protected-access
dataset = dataset._apply_options()
gen_experimental_dataset_ops.save_dataset(
dataset._variant_tensor,
path=path,
shard_func_other_args=shard_func.captured_inputs,
compression=compression,
shard_func=shard_func,
use_shard_func=use_shard_func)
class _LoadDataset(dataset_ops.DatasetSource):
"""A dataset that loads previously saved dataset."""
def __init__(self, path, element_spec, compression=None, reader_func=None):
if reader_func is None:
reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda
lambda x: x,
cycle_length=multiprocessing.cpu_count(),
num_parallel_calls=dataset_ops.AUTOTUNE)
self._path = path
self._element_spec = element_spec
self._compression = compression
self._reader_func = dataset_ops.StructuredFunctionWrapper(
reader_func,
"load()",
# Dataset of datasets of input elements
input_structure=dataset_ops.DatasetSpec(
dataset_ops.DatasetSpec(element_spec)))
variant_tensor = gen_experimental_dataset_ops.load_dataset(
path,
reader_func_other_args=self._reader_func.function.captured_inputs,
compression=compression,
reader_func=self._reader_func.function,
**self._flat_structure)
super(_LoadDataset, self).__init__(variant_tensor)
def _functions(self):
return [self._reader_func]
@property
def element_spec(self):
return self._element_spec
@tf_export("data.experimental.load", v1=[])
def load(path, element_spec, compression=None, reader_func=None):
"""Loads a previously saved dataset.
Example usage:
>>> import tempfile
>>> path = os.path.join(tempfile.gettempdir(), "saved_data")
>>> # Save a dataset
>>> dataset = tf.data.Dataset.range(2)
>>> tf.data.experimental.save(dataset, path)
>>> new_dataset = tf.data.experimental.load(path,
... tf.TensorSpec(shape=(), dtype=tf.int64))
>>> for elem in new_dataset:
... print(elem)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
Note that to load a previously saved dataset, you need to specify
`element_spec` -- a type signature of the elements of the saved dataset, which
can be obtained via `tf.data.Dataset.element_spec`. This requirement exists so
that shape inference of the loaded dataset does not need to perform I/O.
If the default option of sharding the saved dataset was used, the element
order of the saved dataset will be preserved when loading it.
The `reader_func` argument can be used to specify a custom order in which
elements should be loaded from the individual shards. The `reader_func` is
expected to take a single argument -- a dataset of datasets, each containing
elements of one of the shards -- and return a dataset of elements. For
example, the order of shards can be shuffled when loading them as follows:
```python
def custom_reader_func(datasets):
  datasets = datasets.shuffle(NUM_SHARDS)
  return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

dataset = tf.data.experimental.load(
    path="/path/to/data", ..., reader_func=custom_reader_func)
```
Args:
path: Required. A path pointing to a previously saved dataset.
element_spec: Required. A nested structure of `tf.TypeSpec` objects matching
the structure of an element of the saved dataset and specifying the type
of individual element components.
compression: Optional. The algorithm to use to decompress the data when
reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
reader_func: Optional. A function to control how to read data from shards.
If present, the function will be traced and executed as graph computation.
Returns:
A `tf.data.Dataset` instance.
"""
return _LoadDataset(
path=path,
element_spec=element_spec,
compression=compression,
reader_func=reader_func)
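
To round out the `reader_func` docstring above, here is a self-contained variant of the shard-shuffling example; it is a sketch, and the shard count of 4 comes from the `shard_func` chosen here rather than from the API:

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.gettempdir(), "io_reader_func_example")

# Write four shards explicitly so the shuffle below has something to reorder.
dataset = tf.data.Dataset.range(42)
tf.data.experimental.save(dataset, path, shard_func=lambda x: x % 4)

def custom_reader_func(datasets):
  # `datasets` is a dataset of per-shard datasets; shuffling it randomizes the
  # order in which shards are read, and interleave flattens their elements.
  datasets = datasets.shuffle(4)
  return datasets.interleave(
      lambda x: x, num_parallel_calls=tf.data.experimental.AUTOTUNE)

loaded = tf.data.experimental.load(
    path, dataset.element_spec, reader_func=custom_reader_func)
```

Shuffling the shards gives up the order preservation that the docstring describes for the default settings, in exchange for randomized read order.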

View File

@ -2048,6 +2048,10 @@ tf_module {
name: "LoadAndRemapMatrix"
argspec: "args=[\'ckpt_path\', \'old_tensor_name\', \'row_remapping\', \'col_remapping\', \'initializing_values\', \'num_rows\', \'num_cols\', \'max_rows_in_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
name: "LoadDataset"
argspec: "args=[\'path\', \'reader_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingADAMParameters"
argspec: "args=[\'parameters\', \'momenta\', \'velocities\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -3744,6 +3748,10 @@ tf_module {
name: "Save"
argspec: "args=[\'filename\', \'tensor_names\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "SaveDataset"
argspec: "args=[\'input_dataset\', \'path\', \'shard_func_other_args\', \'shard_func\', \'compression\', \'use_shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\'], "
}
member_method {
name: "SaveSlices"
argspec: "args=[\'filename\', \'tensor_names\', \'shapes_and_slices\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -144,6 +144,10 @@ tf_module {
name: "latency_stats"
argspec: "args=[\'tag\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "load"
argspec: "args=[\'path\', \'element_spec\', \'compression\', \'reader_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "make_batched_features_dataset"
argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
@ -180,6 +184,10 @@ tf_module {
name: "sample_from_datasets"
argspec: "args=[\'datasets\', \'weights\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "save"
argspec: "args=[\'dataset\', \'path\', \'compression\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "scan"
argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"

View File

@ -2048,6 +2048,10 @@ tf_module {
name: "LoadAndRemapMatrix"
argspec: "args=[\'ckpt_path\', \'old_tensor_name\', \'row_remapping\', \'col_remapping\', \'initializing_values\', \'num_rows\', \'num_cols\', \'max_rows_in_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
name: "LoadDataset"
argspec: "args=[\'path\', \'reader_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "LoadTPUEmbeddingADAMParameters"
argspec: "args=[\'parameters\', \'momenta\', \'velocities\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -3744,6 +3748,10 @@ tf_module {
name: "Save"
argspec: "args=[\'filename\', \'tensor_names\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "SaveDataset"
argspec: "args=[\'input_dataset\', \'path\', \'shard_func_other_args\', \'shard_func\', \'compression\', \'use_shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\'], "
}
member_method {
name: "SaveSlices"
argspec: "args=[\'filename\', \'tensor_names\', \'shapes_and_slices\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "