Adding skip record functionality to snapshot utils.
PiperOrigin-RevId: 312200718
Change-Id: Icba0dfd19ffc6ddc0ca49f58d241beff7cd27714
parent d3886d23d7
commit efd77d2e45
@@ -62,7 +62,7 @@ Status Writer::Create(Env* env, const std::string& filename,
 }
 
 Status Writer::Initialize(tensorflow::Env* env) {
-  TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
+  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
 #if defined(IS_SLIM_BUILD)
   if (compression_type_ != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
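
The one-line change above switches the writer from truncate-on-open to append-on-open semantics, presumably so records already written to a shard file survive the writer being re-initialized. A minimal sketch of the distinction, assuming the standard tensorflow::Env API (the helper name and error handling here are illustrative only, not part of the change):

// Illustrative only: both calls exist on tensorflow::Env. NewWritableFile
// truncates an existing file, while NewAppendableFile keeps its contents and
// appends to it.
#include <memory>
#include <string>

#include "tensorflow/core/platform/env.h"

tensorflow::Status OpenShardForAppend(
    tensorflow::Env* env, const std::string& filename,
    std::unique_ptr<tensorflow::WritableFile>* file) {
  // Previously: env->NewWritableFile(filename, file);  // starts from empty
  return env->NewAppendableFile(filename, file);        // keeps existing bytes
}
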
@@ -228,13 +228,14 @@ class Reader::Dataset : public DatasetBase {
   explicit Dataset(const std::string& filename, const std::string& compression,
                    const int64 version, const DataTypeVector& dtypes,
                    const std::vector<PartialTensorShape>& shapes,
-                   DatasetContext::Params params)
+                   const int64 start_index, DatasetContext::Params params)
       : DatasetBase(DatasetContext(std::move(params))),
         filename_(filename),
         compression_(compression),
         version_(version),
         dtypes_(dtypes),
-        shapes_(shapes) {}
+        shapes_(shapes),
+        start_index_(start_index) {}
 
   const DataTypeVector& output_dtypes() const override { return dtypes_; }
 
@@ -268,6 +269,7 @@ class Reader::Dataset : public DatasetBase {
   int64 version_;
   DataTypeVector dtypes_;
   std::vector<PartialTensorShape> shapes_;
+  const int64 start_index_;
 
   class Iterator : public DatasetIterator<Dataset> {
    public:
@@ -275,9 +277,10 @@ class Reader::Dataset : public DatasetBase {
         : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return Reader::Create(ctx->env(), dataset()->filename_,
-                            dataset()->compression_, dataset()->version_,
-                            dataset()->dtypes_, &reader_);
+      TF_RETURN_IF_ERROR(Reader::Create(
+          ctx->env(), dataset()->filename_, dataset()->compression_,
+          dataset()->version_, dataset()->dtypes_, &reader_));
+      return reader_->SkipRecords(dataset()->start_index_);
     }
 
    protected:
@@ -397,17 +400,32 @@ Status Reader::MakeNestedDataset(Env* env,
                                  const string& compression_type, int version,
                                  const DataTypeVector& dtypes,
                                  const std::vector<PartialTensorShape>& shapes,
+                                 const int64 start_index,
                                  DatasetBase** output) {
   std::vector<DatasetBase*> datasets;
 
   datasets.reserve(filenames.size());
   for (const auto& filename : filenames) {
+    // TODO(frankchn): The reading pattern could be controlled in a non-round
+    // robin fashion, so we cannot assume a round-robin manner when restoring.
+    int64 dataset_start_index = start_index / filenames.size();
+    if (start_index % filenames.size() > datasets.size()) {
+      dataset_start_index++;
+    }
+
     datasets.push_back(
         new Dataset(filename, compression_type, version, dtypes, shapes,
+                    dataset_start_index,
                     DatasetContext::Params({"snapshot_util::Reader::Dataset",
                                             "snapshot_util_reader_Dataset"})));
   }
 
+  // Rotate the vector such that the first dataset contains the next element
+  // to be produced.
+  std::rotate(datasets.begin(),
+              datasets.begin() + (start_index % filenames.size()),
+              datasets.end());
+
   *output = new NestedDataset(
       datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
                                         "snapshot_util_reader_NestedDataset"}));
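
The skip-count arithmetic and the rotation above are easiest to follow with concrete numbers. A standalone sketch (hypothetical values, not part of the change): with 3 shard files read round-robin and a restore point of start_index = 7, shard 0 must skip 3 records, shards 1 and 2 skip 2 each, and rotating by 7 % 3 = 1 puts shard 1 first, so it yields element 7 next.

// Standalone illustration of the start_index bookkeeping in MakeNestedDataset
// (hypothetical values; not TensorFlow code). Elements 0..6 were produced
// round-robin across 3 shards, so shard 0 already yielded 3 elements and
// shards 1 and 2 yielded 2 each; the next element (index 7) comes from shard 1.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t num_files = 3;
  const int64_t start_index = 7;

  std::vector<int64_t> per_file_skip;
  for (int64_t i = 0; i < num_files; ++i) {
    // Same arithmetic as above: every shard skips the completed rounds, and
    // the first (start_index % num_files) shards skip one more record.
    int64_t skip = start_index / num_files;
    if (start_index % num_files > i) ++skip;
    per_file_skip.push_back(skip);
  }
  assert(per_file_skip == (std::vector<int64_t>{3, 2, 2}));

  // Rotate so the shard holding the next element comes first, mirroring the
  // std::rotate call above.
  std::vector<int64_t> order = {0, 1, 2};
  std::rotate(order.begin(), order.begin() + (start_index % num_files),
              order.end());
  assert(order == (std::vector<int64_t>{1, 2, 0}));
  return 0;
}
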
@@ -463,6 +481,15 @@ Status Reader::Initialize(Env* env) {
   return Status::OK();
 }
 
+Status Reader::SkipRecords(int64 num_records) {
+  // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
+  for (int i = 0; i < num_records; ++i) {
+    std::vector<Tensor> unused_tensors;
+    TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
+  }
+  return Status::OK();
+}
+
 Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
   profiler::TraceMe activity(
       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
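
SkipRecords skips by reading and discarding, so resuming at start_index costs one full parse per skipped record (the TODO above notes the possible optimization). A self-contained toy mock of the resume pattern the iterator uses, with made-up types standing in for Reader and Tensor:

// Toy model only (not TensorFlow code): a reader that walks a sequence of
// records, a SkipRecords that discards the first N, and a caller that resumes
// at start_index before reading the next record.
#include <cassert>
#include <cstddef>
#include <vector>

struct ToyReader {
  std::vector<int> records;  // stand-in for serialized tensors
  size_t pos = 0;

  bool ReadRecord(int* out) {
    if (pos >= records.size()) return false;  // end of shard
    *out = records[pos++];
    return true;
  }

  // Mirrors Reader::SkipRecords: skip by reading and throwing the result away.
  bool SkipRecords(int num_records) {
    for (int i = 0; i < num_records; ++i) {
      int unused;
      if (!ReadRecord(&unused)) return false;
    }
    return true;
  }
};

int main() {
  ToyReader reader{{10, 11, 12, 13, 14}};
  const int start_index = 3;  // resume point saved in the dataset
  assert(reader.SkipRecords(start_index));
  int next;
  assert(reader.ReadRecord(&next));
  assert(next == 13);  // first record produced after resuming
  return 0;
}
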

@@ -130,10 +130,13 @@ class Reader {
                                   const string& compression_type, int version,
                                   const DataTypeVector& dtypes,
                                   const std::vector<PartialTensorShape>& shapes,
+                                  const int64 start_index,
                                   DatasetBase** output);
 
   Status ReadTensors(std::vector<Tensor>* read_tensors);
 
+  Status SkipRecords(int64 num_records);
+
  private:
   explicit Reader(const std::string& filename, const string& compression_type,
                   int version, const DataTypeVector& dtypes);