Add dataset API functionality to snapshot_util::Reader

PiperOrigin-RevId: 307561078
Change-Id: I7b78c70dcd0ad66ee5f2c05f564d0f29ca8b34fc
This commit is contained in:
Frank Chen 2020-04-21 01:28:57 -07:00 committed by TensorFlower Gardener
parent 2d79e9922f
commit 46646156db
3 changed files with 213 additions and 0 deletions

View File

@ -526,6 +526,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/platform:coding",
"//tensorflow/core/platform:random",
"//tensorflow/core/profiler/lib:traceme",

View File

@ -15,10 +15,14 @@ limitations under the License.
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include <queue>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
@ -219,6 +223,198 @@ Status Reader::Create(Env* env, const std::string& filename,
return (*out_reader)->Initialize(env);
}
// A tf.data dataset whose elements are replayed from a single snapshot file.
// Iterating it yields, in order, the tensors previously written to
// `filename` with the given compression scheme and snapshot format version.
class Reader::Dataset : public DatasetBase {
 public:
  // Construction is cheap: the snapshot file is not opened here, only when an
  // iterator over this dataset is initialized.
  explicit Dataset(const std::string& filename, const std::string& compression,
                   const int64 version, const DataTypeVector& dtypes,
                   const std::vector<PartialTensorShape>& shapes,
                   DatasetContext::Params params)
      : DatasetBase(DatasetContext(std::move(params))),
        filename_(filename),
        compression_(compression),
        version_(version),
        dtypes_(dtypes),
        shapes_(shapes) {}

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  std::string DebugString() const override {
    return "snapshot_util::Reader::Dataset";
  }

  // Holds no external (mutable host) state, so this always succeeds.
  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    // TODO(frankchn): Implement for serialization and checkpointing.
    return Status::OK();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    Iterator::Params iterator_params{
        this, name_utils::IteratorPrefix(node_name(), prefix)};
    return absl::make_unique<Iterator>(iterator_params);
  }

 private:
  std::string filename_;
  std::string compression_;
  int64 version_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> shapes_;

  // Streams tensors out of the snapshot file until the underlying Reader
  // reports out-of-range, which is translated into end-of-sequence.
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    // Opens the snapshot file; deferred to initialization so that the
    // Dataset object itself never touches the filesystem.
    Status Initialize(IteratorContext* ctx) override {
      return Reader::Create(ctx->env(), dataset()->filename_,
                            dataset()->compression_, dataset()->version_,
                            dataset()->dtypes_, &reader_);
    }

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = false;
      const Status status = reader_->ReadTensors(out_tensors);
      if (!errors::IsOutOfRange(status)) {
        // Either a successfully read element or a genuine error; propagate.
        return status;
      }
      // Out-of-range simply means the file is exhausted.
      *end_of_sequence = true;
      return Status::OK();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

   private:
    std::unique_ptr<Reader> reader_;
  };
};
// A dataset of datasets: each element is a scalar DT_VARIANT tensor wrapping
// one of the datasets handed to the constructor. Elements are produced in the
// order the nested datasets were supplied.
class Reader::NestedDataset : public DatasetBase {
 public:
  explicit NestedDataset(std::vector<DatasetBase*> datasets,
                         DatasetContext::Params params)
      : DatasetBase(DatasetContext(std::move(params))), datasets_(datasets) {
    // Every element is a single variant tensor of shape [1].
    dtypes_.push_back(DT_VARIANT);
    gtl::InlinedVector<int64, 1> single_element_dims;
    single_element_dims.push_back(1);
    partial_shapes_.emplace_back(single_element_dims);
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return partial_shapes_;
  }

  std::string DebugString() const override {
    return "snapshot_util::Reader::NestedDataset";
  }

  // Holds no external (mutable host) state, so this always succeeds.
  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    // TODO(frankchn): Implement for serialization and checkpointing.
    return Status::OK();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    Iterator::Params iterator_params{
        this, name_utils::IteratorPrefix(node_name(), prefix)};
    return absl::make_unique<Iterator>(iterator_params);
  }

 private:
  std::vector<DatasetBase*> datasets_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> partial_shapes_;

  // Walks `datasets_` once, emitting each dataset as a variant tensor.
  class Iterator : public DatasetIterator<NestedDataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<NestedDataset>(params), index_(0) {}

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& nested = dataset()->datasets_;
      if (nested.size() == index_) {
        *end_of_sequence = true;
        return Status::OK();
      }
      *end_of_sequence = false;
      // Wrap the next nested dataset in a scalar variant tensor.
      Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
      TF_RETURN_IF_ERROR(
          StoreDatasetInVariantTensor(nested[index_], &dataset_tensor));
      out_tensors->clear();
      out_tensors->push_back(std::move(dataset_tensor));
      index_++;
      return Status::OK();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

   private:
    int64 index_;  // Position of the next dataset to emit.
  };
};
// Builds a nested dataset over a set of snapshot files: one inner Dataset per
// file (replaying that file's tensors), wrapped in a NestedDataset that emits
// them as variant tensors in `filenames` order. Ownership of the newly
// allocated datasets is transferred to the caller via `output`.
Status Reader::MakeNestedDataset(Env* env,
                                 const std::vector<std::string>& filenames,
                                 const string& compression_type, int version,
                                 const DataTypeVector& dtypes,
                                 const std::vector<PartialTensorShape>& shapes,
                                 DatasetBase** output) {
  std::vector<DatasetBase*> datasets;
  datasets.reserve(filenames.size());
  for (const std::string& snapshot_file : filenames) {
    DatasetContext::Params dataset_params{
        .type_string = "snapshot_util::Reader::Dataset",
        .node_name = "snapshot_util_reader_Dataset"};
    datasets.push_back(new Dataset(snapshot_file, compression_type, version,
                                   dtypes, shapes, std::move(dataset_params)));
  }
  DatasetContext::Params nested_params{
      .type_string = "snapshot_util::Reader::NestedDataset",
      .node_name = "snapshot_util_reader_NestedDataset"};
  *output = new NestedDataset(std::move(datasets), std::move(nested_params));
  return Status::OK();
}
Reader::Reader(const std::string& filename, const string& compression_type,
int version, const DataTypeVector& dtypes)
: filename_(filename),

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
@ -119,6 +120,18 @@ class Reader {
const DataTypeVector& dtypes,
std::unique_ptr<Reader>* out_reader);
// Returns a nested dataset for a set of given snapshot file names.
//
// This function takes a vector of snapshot files, and returns a nested
// dataset. Each element within the nested dataset is itself a dataset, and
// contains all the elements written out to each individual snapshot file.
static Status MakeNestedDataset(Env* env,
const std::vector<std::string>& filenames,
const string& compression_type, int version,
const DataTypeVector& dtypes,
const std::vector<PartialTensorShape>& shapes,
DatasetBase** output);
Status ReadTensors(std::vector<Tensor>* read_tensors);
private:
@ -150,6 +163,9 @@ class Reader {
int num_simple_ = 0;
int num_complex_ = 0;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
class Dataset;
class NestedDataset;
};
Status WriteMetadataFile(const string& hash_dir,