[tf.data] Implementation of `tf.data.experimental.save` and `tf.data.experimental.load`.

The former makes it possible to save the output of a dataset, while the latter makes it possible to load previously saved data.

Fixes: 38483
PiperOrigin-RevId: 315991164
Change-Id: I30da604fdd489902ff4771b685e413447d3e9e9d
parent 348b3e6224
commit 4d58a67a9f

Changed paths:
  tensorflow/core/api_def/base_api
  tensorflow/core/kernels/data/experimental
  tensorflow/core/ops
  tensorflow/core/protobuf/data/experimental
  tensorflow/python/data/experimental
  tensorflow/tools/api/golden
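A minimal end-to-end sketch of the two new APIs, based on the docstring examples and `io_test.py` further below (the temporary directory and eager execution are assumptions of the sketch, not part of this change):

```python
import os
import tempfile
import tensorflow as tf

path = os.path.join(tempfile.gettempdir(), "saved_data")

# Save the output of a dataset to disk ...
dataset = tf.data.Dataset.range(42)
tf.data.experimental.save(dataset, path)

# ... and load it back later. The element_spec of the original dataset must be
# supplied so that shape inference of the loaded dataset does not perform I/O.
loaded = tf.data.experimental.load(
    path, tf.TensorSpec(shape=(), dtype=tf.int64))

# With the default (round-robin) sharding, element order is preserved.
assert [int(t.numpy()) for t in loaded] == list(range(42))
```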
@ -0,0 +1,4 @@
op {
  graph_op_name: "LoadDataset"
  visibility: HIDDEN
}
@ -0,0 +1,4 @@
op {
  graph_op_name: "SaveDataset"
  visibility: HIDDEN
}
@ -252,6 +252,23 @@ tf_kernel_library(
    ],
)

tf_kernel_library(
    name = "io_ops",
    srcs = ["io_ops.cc"],
    hdrs = ["io_ops.h"],
    deps = [
        ":snapshot_util",
        "//tensorflow/core:experimental_dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:op_requires",
        "//tensorflow/core/kernels/data:captured_function",
        "//tensorflow/core/kernels/data:iterator_ops",
        "//tensorflow/core/kernels/data:name_utils",
    ],
)

tf_kernel_library(
    name = "lmdb_dataset_op",
    srcs = ["lmdb_dataset_op.cc"],

@ -538,7 +555,6 @@ cc_library(
        "//tensorflow/core/platform:random",
        "//tensorflow/core/profiler/lib:traceme",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
    ],
)

@ -576,7 +592,6 @@ tf_kernel_library(
        "//tensorflow/core/profiler/lib:traceme",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
)

@ -724,6 +739,7 @@ tf_kernel_library(
        ":group_by_reducer_dataset_op",
        ":group_by_window_dataset_op",
        ":ignore_errors_dataset_op",
        ":io_ops",
        ":lmdb_dataset_op",
        ":map_and_batch_dataset_op",
        ":matching_files_dataset_op",
tensorflow/core/kernels/data/experimental/io_ops.cc (new file, 377 lines)
@ -0,0 +1,377 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/experimental/io_ops.h"

#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"

namespace tensorflow {
namespace data {
namespace experimental {

SaveDatasetOp::SaveDatasetOp(OpKernelConstruction* ctx)
    : HybridAsyncOpKernel(ctx, "tf_data_save_dataset") {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseShardFunc, &use_shard_func_));
}

Status SaveDatasetOp::DoCompute(OpKernelContext* ctx) {
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));

  tstring path;
  TF_RETURN_IF_ERROR(ParseScalarArgument(ctx, kPath, &path));

  // Create a run directory.
  auto run_id = random::New64();
  auto run_dir = snapshot_util::RunDirectory(path, run_id);
  TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir));
  TF_RETURN_IF_ERROR(
      WriteMetadataFile(ctx->env(), path, run_id, dataset->output_dtypes(),
                        /*num_elements=*/0, /*finalized=*/false));

  std::unique_ptr<CapturedFunction> captured_func;
  TF_RETURN_IF_ERROR(CapturedFunction::Create(
      ctx, func_metadata_, kShardFuncOtherArgs, &captured_func));

  uint64 num_elements = 0;
  TF_RETURN_IF_ERROR(WriteData(ctx, dataset, std::move(captured_func), run_dir,
                               &num_elements));
  TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), path, run_id,
                                       dataset->output_dtypes(), num_elements,
                                       /*finalized=*/true));
  return Status::OK();
}

Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset,
                                std::unique_ptr<CapturedFunction> captured_func,
                                const std::string& run_dir,
                                uint64* num_elements) {
  IteratorContext::Params params(ctx);
  auto function_handle_cache =
      absl::make_unique<FunctionHandleCache>(params.flr);
  params.function_handle_cache = function_handle_cache.get();
  ResourceMgr resource_mgr;
  params.resource_mgr = &resource_mgr;
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  params.cancellation_manager = &cancellation_manager;

  IteratorContext iter_ctx(std::move(params));
  std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
  TF_RETURN_IF_ERROR(
      captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));

  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(
      dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr, "Save", &iterator));

  mutex mu;
  Status status;
  absl::flat_hash_map<int64, std::unique_ptr<snapshot_util::AsyncWriter>>
      writers;
  while (true) {
    if (ctx->cancellation_manager()->IsCancelled()) {
      return errors::Cancelled("Operation was cancelled");
    }
    std::vector<Tensor> element;
    bool end_of_input;
    TF_RETURN_IF_ERROR(iterator->GetNext(&iter_ctx, &element, &end_of_input));
    if (end_of_input) {
      break;
    }
    (*num_elements)++;

    // Run the shard function to compute the shard index.
    int64 shard_index = -1;
    TF_RETURN_IF_ERROR(GetShardIndex(
        &iter_ctx, instantiated_captured_func.get(), element, &shard_index));

    // If the index does not exist, we will start a new thread.
    if (writers.count(shard_index) == 0) {
      const auto snapshot_shard_directory =
          snapshot_util::ShardDirectory(run_dir, shard_index);
      auto writer_thread = std::make_unique<snapshot_util::AsyncWriter>(
          ctx->env(), shard_index, snapshot_shard_directory,
          /*checkpoint_id=*/0, compression_, kFileFormatVersion,
          dataset->output_dtypes(), [&mu, &status](Status s) {
            mutex_lock l(mu);
            status.Update(s);
          });
      writers.insert({shard_index, std::move(writer_thread)});
    }
    writers[shard_index]->Write(element);
  }

  // Push the end of sequence signal to each of the threads to close files.
  for (auto& writer : writers) {
    writer.second->SignalEOF();
  }
  // Wait for the writer threads to join.
  writers.clear();

  return status;
}

Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
                                    InstantiatedCapturedFunction* function,
                                    const std::vector<Tensor>& element,
                                    int64* shard_index) {
  if (!use_shard_func_) {
    *shard_index = (*shard_index + 1) % port::NumSchedulableCPUs();
    return Status::OK();
  }
  std::vector<Tensor> output_tensors;
  TF_RETURN_IF_ERROR(
      function->RunWithBorrowedArgs(ctx, element, &output_tensors));

  if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
      output_tensors[0].NumElements() != 1) {
    return errors::InvalidArgument("`shard_func` must return a scalar int64.");
  }
  *shard_index = output_tensors[0].flat<int64>()(0);
  return Status::OK();
}

Status SaveDatasetOp::WriteMetadataFile(Env* env, const std::string& path,
                                        uint64 run_id,
                                        const DataTypeVector& output_dtypes,
                                        uint64 num_elements, bool finalized) {
  SnapshotMetadataRecord metadata;
  metadata.set_creation_timestamp(EnvTime::NowMicros());
  metadata.set_run_id(strings::Printf("%llu", run_id));
  metadata.set_version(kFileFormatVersion);
  for (const auto& output_dtype : output_dtypes) {
    metadata.add_dtype(output_dtype);
  }
  metadata.set_finalized(finalized);
  metadata.set_num_elements(num_elements);
  return snapshot_util::WriteMetadataFile(env, path, &metadata);
}

class LoadDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const tstring& path,
          SnapshotMetadataRecord metadata, const std::string& compression,
          std::unique_ptr<CapturedFunction> captured_func,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : DatasetBase(DatasetContext(ctx)),
        captured_func_(std::move(captured_func)),
        compression_(compression),
        metadata_(std::move(metadata)),
        output_types_(output_types),
        output_shapes_(output_shapes),
        path_(path) {}

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64 Cardinality() const override { return metadata_.num_elements(); }

  Status CheckExternalState() const override {
    return captured_func_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* path_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(path_, &path_node));

    std::vector<Node*> reader_func_other_args;
    DataTypeVector reader_func_other_args_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(
        ctx, b, &reader_func_other_args, &reader_func_other_args_types));

    // Attr: compression
    AttrValue compression_attr;
    b->BuildAttrValue(compression_, &compression_attr);

    // Attr: reader_func
    AttrValue reader_func_attr;
    b->BuildAttrValue(captured_func_->func(), &reader_func_attr);

    AttrValue reader_func_arguments_types_attr;
    b->BuildAttrValue(reader_func_other_args_types,
                      &reader_func_arguments_types_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {std::make_pair(0, path_node)},  // Single tensor inputs.
        {std::make_pair(1, reader_func_other_args)},  // Tensor list inputs.
        {std::make_pair(kCompression, compression_attr),
         std::make_pair(kReaderFunc, reader_func_attr),
         std::make_pair(kReaderFuncTarguments,
                        reader_func_arguments_types_attr)},  // Attrs
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    ~Iterator() override { input_->Unref(); }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_));
      TF_RETURN_IF_ERROR(InitializeInput(ctx));
      return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeUnknownRatioNode(std::move(args));
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      return errors::Unimplemented("Checkpointing is currently not supported.");
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      return errors::Unimplemented("Checkpointing is currently not supported.");
    }

   private:
    Status InitializeInput(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      auto run_dir = snapshot_util::RunDirectory(dataset()->path_,
                                                 dataset()->metadata_.run_id());

      std::vector<std::string> snapshot_shard_dirs;
      TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
          io::JoinPath(run_dir,
                       strings::Printf("%s%s", "*",
                                       snapshot_util::kShardDirectorySuffix)),
          &snapshot_shard_dirs));
      std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

      DatasetBase* dataset_of_snapshot_files;
      TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
          ctx->env(), snapshot_shard_dirs, dataset()->compression_,
          dataset()->metadata_.version(), dataset()->output_dtypes(),
          dataset()->output_shapes(), /*start_index=*/0,
          &dataset_of_snapshot_files));

      Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
      TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
                                                     &input_dataset_tensor));

      std::vector<Tensor> reader_input;
      std::vector<Tensor> reader_output;
      reader_input.push_back(std::move(input_dataset_tensor));

      TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
          ctx, std::move(reader_input), &reader_output));
      if (reader_output.size() != 1) {
        return errors::InvalidArgument(
            "reader_func returns more than one argument.");
      }
      TF_RETURN_IF_ERROR(
          GetDatasetFromVariantTensor(reader_output[0], &input_));
      // We need to take a reference here as we will use the input_ and
      // its iterator.
      input_->Ref();
      return Status::OK();
    }

    mutex mu_;
    DatasetBase* input_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
  };

  const std::unique_ptr<CapturedFunction> captured_func_;
  const std::string compression_;
  const SnapshotMetadataRecord metadata_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const tstring path_;
};

LoadDatasetOp::LoadDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, /*params=*/{},
                                               &func_metadata_));
}

void LoadDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
  tstring path;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPath, &path));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(
      ctx, CapturedFunction::Create(ctx, func_metadata_, kReaderFuncOtherArgs,
                                    &captured_func));

  bool metadata_file_exists;
  experimental::SnapshotMetadataRecord metadata;
  OP_REQUIRES_OK(ctx, snapshot_util::ReadMetadataFile(
                          ctx->env(), path, &metadata, &metadata_file_exists));

  OP_REQUIRES(ctx, metadata_file_exists,
              errors::NotFound("Could not find metadata file."));

  *output =
      new Dataset(ctx, path, std::move(metadata), compression_,
                  std::move(captured_func), output_types_, output_shapes_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("SaveDataset").Device(DEVICE_CPU), SaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LoadDataset").Device(DEVICE_CPU), LoadDatasetOp);
}  // namespace

}  // namespace experimental
}  // namespace data
}  // namespace tensorflow
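The shard selection above (`SaveDatasetOp::GetShardIndex`) either round-robins across `port::NumSchedulableCPUs()` or defers to the user-supplied `shard_func`; a rough Python sketch of the two modes, with `num_writers` standing in for the number of schedulable CPUs:

```python
# Sketch of the two sharding modes implemented by SaveDatasetOp::GetShardIndex.
# `num_writers` is illustrative; the kernel uses port::NumSchedulableCPUs().

def round_robin_shard_index(previous_index, num_writers):
  # Default mode (use_shard_func == False): elements are spread across the
  # writer threads in a round-robin fashion.
  return (previous_index + 1) % num_writers

# Custom mode (use_shard_func == True): the user-supplied shard_func must
# return a scalar int64 shard id per element, e.g. two shards of 21 elements
# each, as exercised by testCustomShardFunction in io_test.py below.
custom_shard_func = lambda x: x // 21
```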
tensorflow/core/kernels/data/experimental/io_ops.h (new file, 95 lines)
@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_IO_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_IO_OPS_H_

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace data {
namespace experimental {

// An operation that can save a dataset to one or more files.
class SaveDatasetOp : public HybridAsyncOpKernel {
 public:
  static constexpr const char* const kCompression = "compression";
  static constexpr const char* const kPath = "path";
  static constexpr const char* const kShardFunc = "shard_func";
  static constexpr const char* const kShardFuncOtherArgs =
      "shard_func_other_args";
  static constexpr const char* const kUseShardFunc = "use_shard_func";

  explicit SaveDatasetOp(OpKernelConstruction* ctx);

  Status DoCompute(OpKernelContext* ctx) override;

 private:
  static constexpr const int kFileFormatVersion = 2;

  Status ConsumeElement();

  Status GetShardIndex(IteratorContext* ctx,
                       InstantiatedCapturedFunction* function,
                       const std::vector<Tensor>& element, int64* shard_index);

  Status WriteData(OpKernelContext* ctx, DatasetBase* dataset,
                   std::unique_ptr<CapturedFunction> captured_func,
                   const std::string& run_dir, uint64* num_elements);

  Status WriteMetadataFile(Env* env, const std::string& path, uint64 run_id,
                           const DataTypeVector& output_dtypes,
                           uint64 num_elements, bool finalized);

  bool use_shard_func_;
  std::string compression_;
  std::shared_ptr<FunctionMetadata> func_metadata_;
};

// An operation that can load a dataset from one or more files.
class LoadDatasetOp : public DatasetOpKernel {
 public:
  static constexpr const char* const kCompression = "compression";
  static constexpr const char* const kDatasetType = "Load";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
  static constexpr const char* const kPath = "path";
  static constexpr const char* const kReaderFunc = "reader_func";
  static constexpr const char* const kReaderFuncOtherArgs =
      "reader_func_other_args";
  static constexpr const char* const kReaderFuncTarguments =
      "Treader_func_args";

  explicit LoadDatasetOp(OpKernelConstruction* ctx);

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;

 private:
  class Dataset;

  std::string compression_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  std::shared_ptr<FunctionMetadata> func_metadata_;
};

}  // namespace experimental
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_IO_OPS_H_
@ -514,7 +514,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
  TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
      io::JoinPath(
          run_dir,
          absl::StrFormat("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
          strings::Printf("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
      &snapshot_shard_dirs));
  std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

@ -596,8 +596,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(

  experimental::SnapshotMetadataRecord metadata;
  metadata.set_creation_timestamp(EnvTime::NowMicros());
  metadata.set_graph_hash(absl::StrFormat("%d", dataset()->hash_));
  metadata.set_run_id(absl::StrFormat("%d", run_id_));
  metadata.set_graph_hash(strings::Printf("%llu", dataset()->hash_));
  metadata.set_run_id(strings::Printf("%llu", run_id_));
  metadata.set_version(kFileFormatVersion);
  for (const auto& output_dtype : dataset()->output_dtypes()) {
    metadata.add_dtype(output_dtype);

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"

@ -18,7 +18,6 @@ limitations under the License.
#include <queue>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -37,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"

@ -50,11 +50,11 @@ namespace snapshot_util {
    CustomReader::kSnappyReaderOutputBufferSizeBytes;

std::string HashDirectory(const std::string& path, uint64 hash) {
  return io::JoinPath(path, absl::StrFormat("%d", hash));
  return io::JoinPath(path, strings::Printf("%llu", hash));
}

std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
  return RunDirectory(hash_directory, absl::StrFormat("%d", run_id));
  return RunDirectory(hash_directory, strings::Printf("%llu", run_id));
}

std::string RunDirectory(const std::string& hash_directory,
@ -63,13 +63,13 @@ std::string RunDirectory(const std::string& hash_directory,
}

std::string ShardDirectory(const std::string& run_directory, int64 shard_id) {
  return io::JoinPath(run_directory, absl::StrFormat("%08d%s", shard_id,
  return io::JoinPath(run_directory, strings::Printf("%08llu%s", shard_id,
                                                     kShardDirectorySuffix));
}
std::string GetCheckpointFileName(const std::string& shard_directory,
                                  uint64 checkpoint_id) {
  return io::JoinPath(shard_directory,
                      absl::StrFormat("%08d.snapshot", checkpoint_id));
                      strings::Printf("%08llu.snapshot", checkpoint_id));
}

Status Writer::Create(Env* env, const std::string& filename,
@ -914,6 +914,38 @@ REGISTER_OP("SnapshotDatasetV2")
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("SaveDataset")
    .Input("input_dataset: variant")
    .Input("path: string")
    .Input("shard_func_other_args: Tshard_func_args")
    .Attr("compression: string = ''")
    .Attr("shard_func: func")
    .Attr("use_shard_func: bool = true")
    .Attr("Tshard_func_args: list(type) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("LoadDataset")
    .Input("path: string")
    .Input("reader_func_other_args: Treader_func_args")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("compression: string = ''")
    .Attr("reader_func: func")
    .Attr("Treader_func_args: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

REGISTER_OP("SqlDataset")
    .Input("driver_name: string")
    .Input("data_source_name: string")
@ -26,6 +26,8 @@ message SnapshotMetadataRecord {
  int64 version = 4;
  // A list of tensor dtype corresponding to each element of the snapshot.
  repeated .tensorflow.DataType dtype = 5;
  // The number of elements in the snapshot.
  int64 num_elements = 6;

  bool finalized = 1000;
}
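The new `num_elements` field is what lets a loaded dataset report a known cardinality without touching the shard files (`LoadDatasetOp::Dataset::Cardinality()` above simply returns it). A small sketch of the observable behavior, mirroring `testCardinality` in `io_test.py` below; the path is illustrative:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(42)
tf.data.experimental.save(dataset, "/tmp/saved_data")

# The cardinality is read back from the metadata file, not computed by
# iterating over the shards.
loaded = tf.data.experimental.load(
    "/tmp/saved_data", tf.TensorSpec(shape=(), dtype=tf.int64))
print(loaded.cardinality().numpy())  # 42
```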
@ -63,6 +63,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@group_by_window
@@ignore_errors
@@latency_stats
@@load
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
@ -73,6 +74,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@prefetch_to_device
@@rejection_resample
@@sample_from_datasets
@@save
@@scan
@@shuffle_and_repeat
@@snapshot
@ -114,6 +116,8 @@ from tensorflow.python.data.experimental.ops.grouping import Reducer
from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
from tensorflow.python.data.experimental.ops.io import load
from tensorflow.python.data.experimental.ops.io import save
from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.python.data.experimental.ops.optimization_options import MapVectorizationOptions
@ -259,6 +259,18 @@ tf_py_test(
    ],
)

tf_py_test(
    name = "io_test",
    srcs = ["io_test.py"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
        "//tensorflow/python/data/experimental/ops:io",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)

tf_py_test(
    name = "make_batched_features_dataset_test",
    size = "medium",
tensorflow/python/data/experimental/kernel_tests/io_test.py (new file, 87 lines)
@ -0,0 +1,87 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `tf.data.experimental.{save,load}` operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

from absl.testing import parameterized

from tensorflow.python.data.experimental.ops import io
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


class IOTest(test_base.DatasetTestBase, parameterized.TestCase):

  def setUp(self):
    super(IOTest, self).setUp()
    tmpdir = self.get_temp_dir()
    tmpdir = os.path.join(tmpdir, "io_test")
    os.mkdir(tmpdir)
    self._test_dir = tmpdir

  def tearDown(self):
    super(IOTest, self).tearDown()
    shutil.rmtree(self._test_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testBasic(self, compression):
    dataset = dataset_ops.Dataset.range(42)
    io.save(dataset, self._test_dir, compression=compression)
    dataset2 = io.load(
        self._test_dir, dataset.element_spec, compression=compression)
    self.assertDatasetProduces(dataset2, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testCardinality(self):
    dataset = dataset_ops.Dataset.range(42)
    io.save(dataset, self._test_dir)
    dataset2 = io.load(self._test_dir, dataset.element_spec)
    self.assertEqual(self.evaluate(dataset2.cardinality()), 42)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomShardFunction(self):
    dataset = dataset_ops.Dataset.range(42)
    io.save(dataset, self._test_dir, shard_func=lambda x: x // 21)
    dataset2 = io.load(self._test_dir, dataset.element_spec)
    expected = []
    for i in range(21):
      expected.extend([i, i + 21])
    self.assertDatasetProduces(dataset2, expected)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomReaderFunction(self):
    dataset = dataset_ops.Dataset.range(42)
    io.save(dataset, self._test_dir, shard_func=lambda x: x % 7)
    dataset2 = io.load(
        self._test_dir,
        dataset.element_spec,
        reader_func=lambda x: x.flat_map(lambda y: y))
    expected = []
    for i in range(7):
      expected.extend(range(i, 42, 7))
    self.assertDatasetProduces(dataset2, expected)


if __name__ == "__main__":
  test.main()
@ -164,6 +164,18 @@ py_library(
    ],
)

py_library(
    name = "io",
    srcs = [
        "io.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:experimental_dataset_ops_gen",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)

py_library(
    name = "iterator_ops",
    srcs = [
@ -493,6 +505,7 @@ py_library(
        ":get_single_element",
        ":grouping",
        ":interleave_ops",
        ":io",
        ":map_defun",
        ":matching_files",
        ":optimization",
tensorflow/python/data/experimental/ops/io.py (new file, 205 lines)
@ -0,0 +1,205 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for saving and loading a dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "NONE"


@tf_export("data.experimental.save", v1=[])
def save(dataset, path, compression=None, shard_func=None):
  """Saves the content of the given dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path,
  ...     tf.TensorSpec(shape=(), dtype=tf.int64))
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The dataset is saved in multiple file "shards". By default, the dataset
  output is divided into shards in a round-robin fashion, but custom sharding
  can be specified via the `shard_func` function. For example, you can save
  the dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return 0
  dataset = tf.data.experimental.save(
      path="/path/to/data", ..., shard_func=custom_shard_func)
  ```

  NOTE: The directory layout and file format used for saving the dataset is
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
  """

  if shard_func is None:
    use_shard_func = False
    shard_func = lambda *x: None  # a dummy function that will not be used
  else:
    use_shard_func = True

  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      shard_func,
      "save()",
      input_structure=dataset.element_spec,
      add_to_graph=False)

  path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
  shard_func = wrapped_func.function
  shard_func.add_to_graph(ops.get_default_graph())

  # pylint: disable=protected-access
  dataset = dataset._apply_options()
  gen_experimental_dataset_ops.save_dataset(
      dataset._variant_tensor,
      path=path,
      shard_func_other_args=shard_func.captured_inputs,
      compression=compression,
      shard_func=shard_func,
      use_shard_func=use_shard_func)


class _LoadDataset(dataset_ops.DatasetSource):
  """A dataset that loads a previously saved dataset."""

  def __init__(self, path, element_spec, compression=None, reader_func=None):

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._path = path
    self._element_spec = element_spec
    self._compression = compression

    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        "load()",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(element_spec)))

    variant_tensor = gen_experimental_dataset_ops.load_dataset(
        path,
        reader_func_other_args=self._reader_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        **self._flat_structure)
    super(_LoadDataset, self).__init__(variant_tensor)

  def _functions(self):
    return [self._reader_func]

  @property
  def element_spec(self):
    return self._element_spec


@tf_export("data.experimental.load", v1=[])
def load(path, element_spec, compression=None, reader_func=None):
  """Loads a previously saved dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path,
  ...     tf.TensorSpec(shape=(), dtype=tf.int64))
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  Note that to load a previously saved dataset, you need to specify
  `element_spec` -- a type signature of the elements of the saved dataset,
  which can be obtained via `tf.data.Dataset.element_spec`. This requirement
  exists so that shape inference of the loaded dataset does not need to
  perform I/O.

  If the default option of sharding the saved dataset was used, the element
  order of the saved dataset will be preserved when loading it.

  The `reader_func` argument can be used to specify a custom order in which
  elements should be loaded from the individual shards. The `reader_func` is
  expected to take a single argument -- a dataset of datasets, each containing
  elements of one of the shards -- and return a dataset of elements. For
  example, the order of shards can be shuffled when loading them as follows:

  ```python
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(NUM_SHARDS)
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = tf.data.experimental.load(
      path="/path/to/data", ..., reader_func=custom_reader_func)
  ```

  Args:
    path: Required. A path pointing to a previously saved dataset.
    element_spec: Required. A nested structure of `tf.TypeSpec` objects
      matching the structure of an element of the saved dataset and specifying
      the type of individual element components.
    compression: Optional. The algorithm to use to decompress the data when
      reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    reader_func: Optional. A function to control how to read data from shards.
      If present, the function will be traced and executed as graph
      computation.

  Returns:
    A `tf.data.Dataset` instance.
  """

  return _LoadDataset(
      path=path,
      element_spec=element_spec,
      compression=compression,
      reader_func=reader_func)
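As a usage note, the `reader_func` hook can also change the order in which shards are consumed; a runnable sketch expanding the docstring example above, where `NUM_SHARDS` and the path are placeholders:

```python
import tensorflow as tf

NUM_SHARDS = 8  # placeholder; depends on how the dataset was saved

def custom_reader_func(datasets):
  # `datasets` is a dataset of per-shard datasets; shuffle the shard order
  # before interleaving their elements.
  datasets = datasets.shuffle(NUM_SHARDS)
  return datasets.interleave(
      lambda x: x, num_parallel_calls=tf.data.experimental.AUTOTUNE)

loaded = tf.data.experimental.load(
    "/tmp/saved_data", tf.TensorSpec(shape=(), dtype=tf.int64),
    reader_func=custom_reader_func)
```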
@ -2048,6 +2048,10 @@ tf_module {
    name: "LoadAndRemapMatrix"
    argspec: "args=[\'ckpt_path\', \'old_tensor_name\', \'row_remapping\', \'col_remapping\', \'initializing_values\', \'num_rows\', \'num_cols\', \'max_rows_in_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
  }
  member_method {
    name: "LoadDataset"
    argspec: "args=[\'path\', \'reader_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingADAMParameters"
    argspec: "args=[\'parameters\', \'momenta\', \'velocities\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -3744,6 +3748,10 @@ tf_module {
    name: "Save"
    argspec: "args=[\'filename\', \'tensor_names\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "SaveDataset"
    argspec: "args=[\'input_dataset\', \'path\', \'shard_func_other_args\', \'shard_func\', \'compression\', \'use_shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\'], "
  }
  member_method {
    name: "SaveSlices"
    argspec: "args=[\'filename\', \'tensor_names\', \'shapes_and_slices\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

@ -144,6 +144,10 @@ tf_module {
    name: "latency_stats"
    argspec: "args=[\'tag\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "load"
    argspec: "args=[\'path\', \'element_spec\', \'compression\', \'reader_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "make_batched_features_dataset"
    argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
@ -180,6 +184,10 @@ tf_module {
    name: "sample_from_datasets"
    argspec: "args=[\'datasets\', \'weights\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "save"
    argspec: "args=[\'dataset\', \'path\', \'compression\', \'shard_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "scan"
    argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"

@ -2048,6 +2048,10 @@ tf_module {
    name: "LoadAndRemapMatrix"
    argspec: "args=[\'ckpt_path\', \'old_tensor_name\', \'row_remapping\', \'col_remapping\', \'initializing_values\', \'num_rows\', \'num_cols\', \'max_rows_in_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
  }
  member_method {
    name: "LoadDataset"
    argspec: "args=[\'path\', \'reader_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
  }
  member_method {
    name: "LoadTPUEmbeddingADAMParameters"
    argspec: "args=[\'parameters\', \'momenta\', \'velocities\', \'num_shards\', \'shard_id\', \'table_id\', \'table_name\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'\', \'None\'], "
@ -3744,6 +3748,10 @@ tf_module {
    name: "Save"
    argspec: "args=[\'filename\', \'tensor_names\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "SaveDataset"
    argspec: "args=[\'input_dataset\', \'path\', \'shard_func_other_args\', \'shard_func\', \'compression\', \'use_shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\'], "
  }
  member_method {
    name: "SaveSlices"
    argspec: "args=[\'filename\', \'tensor_names\', \'shapes_and_slices\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "