Adding skip record functionality to snapshot utils.

PiperOrigin-RevId: 312200718
Change-Id: Icba0dfd19ffc6ddc0ca49f58d241beff7cd27714
Frank Chen 2020-05-18 19:21:37 -07:00 committed by TensorFlower Gardener
parent d3886d23d7
commit efd77d2e45
2 changed files with 36 additions and 6 deletions

First changed file: the snapshot utils implementation (.cc):

@@ -62,7 +62,7 @@ Status Writer::Create(Env* env, const std::string& filename,
 }

 Status Writer::Initialize(tensorflow::Env* env) {
-  TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
+  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
 #if defined(IS_SLIM_BUILD)
   if (compression_type_ != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
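
A side effect worth noting: Env::NewWritableFile truncates an existing file on open, while Env::NewAppendableFile preserves its contents, which is presumably what lets a restored writer continue a partially written snapshot instead of clobbering it. A minimal sketch of the Env-level distinction (the function name and usage are illustrative, not from the commit):

#include <memory>
#include <string>

#include "tensorflow/core/platform/env.h"

// Opening for resume: NewWritableFile would discard records already on
// disk; NewAppendableFile keeps them and continues writing at the end.
tensorflow::Status OpenForResume(
    tensorflow::Env* env, const std::string& filename,
    std::unique_ptr<tensorflow::WritableFile>* file) {
  return env->NewAppendableFile(filename, file);  // preserves existing bytes
}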
@@ -228,13 +228,14 @@ class Reader::Dataset : public DatasetBase {
   explicit Dataset(const std::string& filename, const std::string& compression,
                    const int64 version, const DataTypeVector& dtypes,
                    const std::vector<PartialTensorShape>& shapes,
-                   DatasetContext::Params params)
+                   const int64 start_index, DatasetContext::Params params)
       : DatasetBase(DatasetContext(std::move(params))),
         filename_(filename),
         compression_(compression),
         version_(version),
         dtypes_(dtypes),
-        shapes_(shapes) {}
+        shapes_(shapes),
+        start_index_(start_index) {}

   const DataTypeVector& output_dtypes() const override { return dtypes_; }
@@ -268,6 +269,7 @@ class Reader::Dataset : public DatasetBase {
   int64 version_;
   DataTypeVector dtypes_;
   std::vector<PartialTensorShape> shapes_;
+  const int64 start_index_;

   class Iterator : public DatasetIterator<Dataset> {
    public:
@@ -275,9 +277,10 @@ class Reader::Dataset : public DatasetBase {
       : DatasetIterator<Dataset>(params) {}

   Status Initialize(IteratorContext* ctx) override {
-    return Reader::Create(ctx->env(), dataset()->filename_,
-                          dataset()->compression_, dataset()->version_,
-                          dataset()->dtypes_, &reader_);
+    TF_RETURN_IF_ERROR(Reader::Create(
+        ctx->env(), dataset()->filename_, dataset()->compression_,
+        dataset()->version_, dataset()->dtypes_, &reader_));
+    return reader_->SkipRecords(dataset()->start_index_);
   }

  protected:
@@ -397,17 +400,32 @@ Status Reader::MakeNestedDataset(Env* env,
                                  const string& compression_type, int version,
                                  const DataTypeVector& dtypes,
                                  const std::vector<PartialTensorShape>& shapes,
+                                 const int64 start_index,
                                  DatasetBase** output) {
   std::vector<DatasetBase*> datasets;

   datasets.reserve(filenames.size());
   for (const auto& filename : filenames) {
+    // TODO(frankchn): The reading pattern could be controlled in a non-round
+    // robin fashion, so we cannot assume a round-robin manner when restoring.
+    int64 dataset_start_index = start_index / filenames.size();
+    if (start_index % filenames.size() > datasets.size()) {
+      dataset_start_index++;
+    }
+
     datasets.push_back(
         new Dataset(filename, compression_type, version, dtypes, shapes,
+                    dataset_start_index,
                     DatasetContext::Params({"snapshot_util::Reader::Dataset",
                                             "snapshot_util_reader_Dataset"})));
   }

+  // Rotate the vector such that the first dataset contains the next element
+  // to be produced.
+  std::rotate(datasets.begin(),
+              datasets.begin() + (start_index % filenames.size()),
+              datasets.end());
+
   *output = new NestedDataset(
       datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
                                         "snapshot_util_reader_NestedDataset"}));
@@ -463,6 +481,15 @@ Status Reader::Initialize(Env* env) {
   return Status::OK();
 }

+Status Reader::SkipRecords(int64 num_records) {
+  // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
+  for (int i = 0; i < num_records; ++i) {
+    std::vector<Tensor> unused_tensors;
+    TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
+  }
+  return Status::OK();
+}
+
 Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
   profiler::TraceMe activity(
       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
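
As the TODO notes, SkipRecords currently parses and discards each record, so resuming costs one full ReadTensors pass per skipped element rather than a true seek. A minimal caller-side sketch of the new API, assuming the header path, the enclosing namespaces, and that the out-parameter of Reader::Create is a std::unique_ptr<Reader> (the diff only shows &reader_):

// Sketch only, not from the commit; header path is an assumption.
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/kernels/data/snapshot_utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace data {
namespace snapshot_util {

// Read the first record at or after `start_index` from a single shard.
Status ReadFromOffset(Env* env, const std::string& filename,
                      const std::string& compression, int version,
                      const DataTypeVector& dtypes, int64 start_index,
                      std::vector<Tensor>* out) {
  std::unique_ptr<Reader> reader;  // assumed type of the out-parameter
  TF_RETURN_IF_ERROR(
      Reader::Create(env, filename, compression, version, dtypes, &reader));
  // O(start_index) reads until the TODO above lands a true skip.
  TF_RETURN_IF_ERROR(reader->SkipRecords(start_index));
  return reader->ReadTensors(out);  // first live record after the skip
}

}  // namespace snapshot_util
}  // namespace data
}  // namespace tensorflow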

Second changed file: the snapshot utils header (.h):

@@ -130,10 +130,13 @@ class Reader {
                                   const string& compression_type, int version,
                                   const DataTypeVector& dtypes,
                                   const std::vector<PartialTensorShape>& shapes,
+                                  const int64 start_index,
                                   DatasetBase** output);

   Status ReadTensors(std::vector<Tensor>* read_tensors);

+  Status SkipRecords(int64 num_records);
+
  private:
   explicit Reader(const std::string& filename, const string& compression_type,
                   int version, const DataTypeVector& dtypes);
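
Putting the header change to work, a hedged sketch of resuming a sharded read through MakeNestedDataset (same assumed includes as the previous sketch; the type of the elided filenames parameter is inferred from the range-for in the .cc hunk, and the version and shape values are invented for illustration):

// Sketch only, not from the commit.
Status ResumeNestedRead(Env* env, DatasetBase** nested) {
  std::vector<std::string> filenames = {"shard_0", "shard_1", "shard_2"};
  DataTypeVector dtypes = {DT_INT64};
  std::vector<PartialTensorShape> shapes = {PartialTensorShape()};
  // start_index = 8 resumes at global element 8: the shards skip 3/3/2
  // records and the vector is rotated so shard_2 is produced first.
  return Reader::MakeNestedDataset(env, filenames, io::compression::kNone,
                                   /*version=*/2, dtypes, shapes,
                                   /*start_index=*/8, nested);
}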