[tf.data] Add checkpointing for CsvDataset

PiperOrigin-RevId: 205078174
Rachel Lim 2018-07-18 07:55:00 -07:00 committed by TensorFlower Gardener
parent a46c9ab441
commit 9b3a29889f
3 changed files with 179 additions and 8 deletions
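
What the change enables, as a minimal user-level sketch (assuming TF 1.x graph mode and the contrib API of this era; "data.csv" and the checkpoint path are placeholder names, not part of this commit):

import tensorflow as tf
from tensorflow.contrib.data.python.ops import readers

# Build a CsvDataset iterator; with this commit its read position can be
# checkpointed and restored like other saveable iterators.
dataset = readers.CsvDataset("data.csv", record_defaults=[[0]] * 3)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Register the iterator state with the Saver machinery.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  sess.run(next_element)             # consume a few rows
  saver.save(sess, "/tmp/csv_ckpt")  # file index and buffer position saved too

with tf.Session() as sess:
  sess.run(iterator.initializer)
  saver.restore(sess, "/tmp/csv_ckpt")  # overrides the position; resumes mid-file
  sess.run(next_element)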


@@ -150,6 +150,7 @@ class CSVDatasetOp : public DatasetOpKernel {
          delim_(delim),
          na_value_(std::move(na_value)),
          use_compression_(!compression_type.empty()),
+         compression_type_(std::move(compression_type)),
          options_(options) {}

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
@@ -169,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel {
   protected:
    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                              Node** output) const override {
-     // TODO(rachelim): Implement this
-     std::vector<Node*> input_tensors;
-     TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
-     return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+     Node* filenames = nullptr;
+     Node* compression_type = nullptr;
+     Node* buffer_size = nullptr;
+     Node* header = nullptr;
+     Node* delim = nullptr;
+     Node* use_quote_delim = nullptr;
+     Node* na_value = nullptr;
+     Node* select_cols = nullptr;
+     std::vector<Node*> record_defaults;
+     record_defaults.reserve(record_defaults_.size());
+     for (const Tensor& t : record_defaults_) {
+       Node* node;
+       TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+       record_defaults.emplace_back(node);
+     }
+     TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+     TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+     TF_RETURN_IF_ERROR(
+         b->AddScalar(options_.input_buffer_size, &buffer_size));
+     TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+     string delim_string(1, delim_);
+     TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+     TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+     TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+     TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+     TF_RETURN_IF_ERROR(b->AddDataset(
+         this,
+         {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+          std::make_pair(2, buffer_size), std::make_pair(3, header),
+          std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+          std::make_pair(6, na_value),
+          std::make_pair(7, select_cols)},  // Single tensor inputs
+         {std::make_pair(8, record_defaults)},  // Tensor list inputs
+         {}, output));
+     return Status::OK();
    }

   private:
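
For reference, the make_pair indices above define the rebuilt op's input order. A sketch of that mapping in Python (names taken from the node variables in this hunk; this documents the serialization layout rather than any public API):

# Input index -> graph input rebuilt by AsGraphDefInternal.
CSV_DATASET_INPUT_ORDER = {
    0: "filenames",
    1: "compression_type",
    2: "buffer_size",
    3: "header",
    4: "delim",
    5: "use_quote_delim",
    6: "na_value",
    7: "select_cols",
    8: "record_defaults",  # tensor list input: one default per selected column
}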
@@ -224,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel {
   protected:
    Status SaveInternal(IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
-     // TODO(rachelim): Implement save
-     return errors::Unimplemented("CSVDataset: SaveInternal");
+     TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                            current_file_index_));
+     // `input_stream_` is empty if
+     // 1. GetNext has not been called even once.
+     // 2. All files have been read and the iterator has been exhausted.
+     if (input_stream_ && num_buffer_reads_ > 0) {
+       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+       // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+                                              num_buffer_reads_));
+     }
+     return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
-     // TODO(rachelim): Implement restore
-     return errors::Unimplemented("CSVDataset: RestoreInternal");
+     ResetStreamsLocked();
+     int64 current_file_index;
+     TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                           &current_file_index));
+     current_file_index_ = size_t(current_file_index);
+     // The keys "pos" and "num_buffer_reads" are written only if
+     // the iterator was saved with an open, partially read file.
+     if (reader->Contains(full_name("pos"))) {
+       int64 pos, num_buffer_reads;
+       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+                                             &num_buffer_reads));
+       TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+       num_buffer_reads_ = size_t(num_buffer_reads - 1);
+       // Restores the most recently held buffer.
+       Status s = input_stream_->SkipNBytes(
+           num_buffer_reads_ * dataset()->options_.input_buffer_size);
+       if (!s.ok() && !errors::IsOutOfRange(s)) {
+         // We might get an out-of-range error here if the size of the file
+         // is not an exact multiple of the buffer size, and the last buffer
+         // read is < buffer_size. This is valid and we do not surface the
+         // error.
+         return s;
+       }
+       Status s2 = FillBuffer(&buffer_);
+       if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+         return s2;
+       }
+       pos_ = size_t(pos);
+     }
+     return Status::OK();
    }

   private:
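
The restore path above reopens the current file, skips every fully consumed buffer except the most recent one, re-fills that buffer via FillBuffer, and resumes parsing at the saved pos. A minimal sketch of the same bookkeeping on a plain uncompressed file (hypothetical helper, not the TensorFlow implementation; buffer_size stands in for options_.input_buffer_size):

def restore_read_state(path, saved_num_buffer_reads, saved_pos, buffer_size=4):
  """Mirrors RestoreInternal's buffer arithmetic on an ordinary file."""
  with open(path, "rb") as f:
    # RestoreInternal sets num_buffer_reads_ = saved - 1 and skips that many
    # whole buffers (SkipNBytes)...
    f.seek((saved_num_buffer_reads - 1) * buffer_size)
    # ...then FillBuffer re-reads the most recently held buffer, incrementing
    # num_buffer_reads_ back to its saved value.
    buf = f.read(buffer_size)
  # Parsing resumes at the saved index into that buffer.
  return buf, saved_pos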
@@ -533,6 +613,7 @@ class CSVDatasetOp : public DatasetOpKernel {
    Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      result->clear();
+     ++num_buffer_reads_;
      Status s = input_stream_->ReadNBytes(
          dataset()->options_.input_buffer_size, result);
@@ -712,6 +793,7 @@ class CSVDatasetOp : public DatasetOpKernel {
      }
      buffer_.clear();
      pos_ = 0;
+     num_buffer_reads_ = 0;
      if (dataset()->header_) {
        // Read one line, but don't include it. Pass nullptrs as dummy
        // pointers to objects that shouldn't be invoked anyway
@@ -737,6 +819,7 @@ class CSVDatasetOp : public DatasetOpKernel {
      string buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
      size_t pos_ GUARDED_BY(
          mu_);  // Index into the buffer must be maintained between iters
+     size_t num_buffer_reads_ GUARDED_BY(mu_);
      std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
          GUARDED_BY(mu_);
      std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
@@ -755,6 +838,7 @@ class CSVDatasetOp : public DatasetOpKernel {
    const char delim_;
    const string na_value_;
    const bool use_compression_;
+   const string compression_type_;
    const io::ZlibCompressionOptions options_;
  };  // class Dataset


@@ -72,6 +72,20 @@ py_test(
    ],
)

+py_test(
+    name = "csv_dataset_serialization_test",
+    size = "small",
+    srcs = ["csv_dataset_serialization_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":dataset_serialization_test_base",
+        "//tensorflow/contrib/data/python/ops:readers",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_ops",
+    ],
+)
+
py_test(
    name = "dataset_constructor_serialization_test",
    size = "medium",


@@ -0,0 +1,73 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the CsvDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.platform import test
class CsvDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def setUp(self):
    self._num_cols = 7
    self._num_rows = 10
    self._num_epochs = 14
    self._num_outputs = self._num_rows * self._num_epochs

    inputs = [
        ",".join(str(self._num_cols * j + i)
                 for i in range(self._num_cols))
        for j in range(self._num_rows)
    ]
    contents = "\n".join(inputs).encode("utf-8")

    self._filename = os.path.join(self.get_temp_dir(), "file.csv")
    self._compressed = os.path.join(self.get_temp_dir(),
                                    "comp.csv")  # GZip compressed
    with open(self._filename, "wb") as f:
      f.write(contents)
    with gzip.GzipFile(self._compressed, "wb") as f:
      f.write(contents)

  def ds_func(self, **kwargs):
    compression_type = kwargs.get("compression_type", None)
    if compression_type == "GZIP":
      filename = self._compressed
    elif compression_type is None:
      filename = self._filename
    else:
      raise ValueError("Invalid compression type:", compression_type)

    return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)

  def testSerializationCore(self):
    defs = [[0]] * self._num_cols
    self.run_core_tests(
        lambda: self.ds_func(record_defaults=defs, buffer_size=2),
        lambda: self.ds_func(record_defaults=defs, buffer_size=12),
        self._num_outputs)


if __name__ == "__main__":
  test.main()