[tf.data] Add checkpointing for CsvDataset
PiperOrigin-RevId: 205078174
Parent: a46c9ab441
Commit: 9b3a29889f
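
This change lets the CsvDataset iterator participate in tf.data iterator checkpointing: AsGraphDefInternal now serializes the dataset's construction arguments into graph nodes, and SaveInternal/RestoreInternal persist the current file index, the in-buffer position, and the buffer-read count. As a rough illustration only (not part of this commit), checkpointing could be driven from the TF 1.x contrib API along the lines of the sketch below; the file name, record defaults, and checkpoint path are placeholders.

# Illustrative sketch (not from this commit): resuming a CsvDataset mid-stream
# via iterator checkpointing in the TF 1.x contrib API. "data.csv", the
# record defaults, and "/tmp/csv_ckpt" are placeholders.
import tensorflow as tf

dataset = tf.contrib.data.CsvDataset(["data.csv"], record_defaults=[[0], [0]])
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Register the iterator state so a regular Saver checkpoints it.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_element))      # consume a row or two...
  saver.save(sess, "/tmp/csv_ckpt")  # ...then checkpoint mid-file

with tf.Session() as sess:
  sess.run(iterator.initializer)
  saver.restore(sess, "/tmp/csv_ckpt")  # resumes after the consumed rows
  print(sess.run(next_element))
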
@@ -150,6 +150,7 @@ class CSVDatasetOp : public DatasetOpKernel {
           delim_(delim),
           na_value_(std::move(na_value)),
           use_compression_(!compression_type.empty()),
+          compression_type_(std::move(compression_type)),
           options_(options) {}
 
     std::unique_ptr<IteratorBase> MakeIteratorInternal(
@@ -169,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel {
    protected:
     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      // TODO(rachelim): Implement this
-      std::vector<Node*> input_tensors;
-      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
-      return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+      Node* filenames = nullptr;
+      Node* compression_type = nullptr;
+      Node* buffer_size = nullptr;
+      Node* header = nullptr;
+      Node* delim = nullptr;
+      Node* use_quote_delim = nullptr;
+      Node* na_value = nullptr;
+      Node* select_cols = nullptr;
+
+      std::vector<Node*> record_defaults;
+      record_defaults.reserve(record_defaults_.size());
+      for (const Tensor& t : record_defaults_) {
+        Node* node;
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        record_defaults.emplace_back(node);
+      }
+
+      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+      TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+      TF_RETURN_IF_ERROR(
+          b->AddScalar(options_.input_buffer_size, &buffer_size));
+      TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+      string delim_string(1, delim_);
+      TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+      TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this,
+          {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+           std::make_pair(2, buffer_size), std::make_pair(3, header),
+           std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+           std::make_pair(6, na_value),
+           std::make_pair(7, select_cols)},  // Single tensor inputs
+          {std::make_pair(8, record_defaults)},  // Tensor list inputs
+          {}, output));
+      return Status::OK();
     }
 
    private:
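
The std::make_pair indices above pin down the input order of the CSVDataset op: eight single-tensor inputs followed by the record_defaults tensor list. As a hedged illustration (not part of this diff), this is roughly how those inputs line up with the keyword arguments of tf.contrib.data.CsvDataset in the TF 1.x contrib API; the values shown are placeholders, not defaults.

# Rough correspondence between the op inputs serialized above and the Python
# constructor arguments (tf.contrib.data.CsvDataset, TF 1.x contrib API);
# all values below are placeholders.
import tensorflow as tf

dataset = tf.contrib.data.CsvDataset(
    filenames=["data.csv"],        # input 0
    compression_type="",           # input 1
    buffer_size=1024,              # input 2
    header=False,                  # input 3
    field_delim=",",               # input 4
    use_quote_delim=True,          # input 5
    na_value="",                   # input 6
    select_cols=[0, 1],            # input 7
    record_defaults=[[0], [0.0]])  # input 8 (tensor list)
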
@@ -224,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel {
    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      // TODO(rachelim): Implement save
-      return errors::Unimplemented("CSVDataset: SaveInternal");
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                             current_file_index_));
+      // `input_stream_` is empty if
+      // 1. GetNext has not been called even once.
+      // 2. All files have been read and the iterator has been exhausted.
+      if (input_stream_ && num_buffer_reads_ > 0) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+        // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+                                               num_buffer_reads_));
+      }
+      return Status::OK();
     }
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      // TODO(rachelim): Implement restore
-      return errors::Unimplemented("CSVDataset: RestoreInternal");
+      ResetStreamsLocked();
+      int64 current_file_index;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                            &current_file_index));
+      current_file_index_ = size_t(current_file_index);
+      // The keys "pos" and "num_buffer_reads" are written only if
+      // the iterator was saved with an open, partially read file.
+      if (reader->Contains(full_name("pos"))) {
+        int64 pos, num_buffer_reads;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+                                              &num_buffer_reads));
+
+        TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+        num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+        // Restores the most recently held buffer
+        Status s = input_stream_->SkipNBytes(
+            num_buffer_reads_ * dataset()->options_.input_buffer_size);
+        if (!s.ok() && !errors::IsOutOfRange(s)) {
+          // We might get out of range error here if the size of the file
+          // is not an exact multiple of the buffer size, and the last buffer
+          // read is < buffer_size. This is valid and we do not surface the
+          // error.
+          return s;
+        }
+
+        Status s2 = FillBuffer(&buffer_);
+        if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+          return s2;
+        }
+        pos_ = size_t(pos);
+      }
+      return Status::OK();
     }
 
    private:
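
The restore path above rebuilds the reader from scratch: it reopens the current file, skips every buffer that was fully consumed before the checkpoint, refills the in-flight buffer, and re-positions pos_ inside it. FillBuffer increments num_buffer_reads_ again, which is why the saved count is decremented first. A toy, TensorFlow-free model of that bookkeeping, just to make the arithmetic explicit (the names here are illustrative, not the kernel's):

# Toy model (plain Python, no TensorFlow) of the buffer bookkeeping that
# SaveInternal/RestoreInternal implement above; "stream" stands in for the
# kernel's input_stream_ and buffer_size for options_.input_buffer_size.
import io


def restore_buffer_state(stream, buffer_size, saved_pos, saved_num_buffer_reads):
  # The checkpoint recorded how many buffers had been filled. The stream was
  # just reopened at the start of the current file, so skip every buffer that
  # was fully read before the one in use at save time (~ SkipNBytes).
  num_buffer_reads = saved_num_buffer_reads - 1
  stream.seek(num_buffer_reads * buffer_size)

  # Refill the in-flight buffer (~ FillBuffer, which bumps the counter again).
  buffer = stream.read(buffer_size)
  num_buffer_reads += 1

  # Resume parsing mid-buffer at the saved offset.
  return buffer, saved_pos, num_buffer_reads


# Example: a 10-byte "file" with buffer_size=4, checkpointed after two buffer
# fills with 1 byte of the second buffer already consumed.
buf, pos, n = restore_buffer_state(io.BytesIO(b"0123456789"), 4,
                                   saved_pos=1, saved_num_buffer_reads=2)
assert buf == b"4567" and pos == 1 and n == 2
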
@@ -533,6 +613,7 @@ class CSVDatasetOp : public DatasetOpKernel {
 
     Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       result->clear();
+      ++num_buffer_reads_;
       Status s = input_stream_->ReadNBytes(
           dataset()->options_.input_buffer_size, result);
 
@@ -712,6 +793,7 @@ class CSVDatasetOp : public DatasetOpKernel {
       }
       buffer_.clear();
       pos_ = 0;
+      num_buffer_reads_ = 0;
       if (dataset()->header_) {
         // Read one line, but don't include it. Pass nullptrs as dummy
         // pointers to objects that shouldn't be invoked anyway
@@ -737,6 +819,7 @@ class CSVDatasetOp : public DatasetOpKernel {
     string buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
     size_t pos_ GUARDED_BY(
         mu_);  // Index into the buffer must be maintained between iters
+    size_t num_buffer_reads_ GUARDED_BY(mu_);
     std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
         GUARDED_BY(mu_);
     std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
@@ -755,6 +838,7 @@ class CSVDatasetOp : public DatasetOpKernel {
     const char delim_;
     const string na_value_;
     const bool use_compression_;
+    const string compression_type_;
     const io::ZlibCompressionOptions options_;
   };  // class Dataset
 

@@ -72,6 +72,20 @@ py_test(
     ],
 )
 
+py_test(
+    name = "csv_dataset_serialization_test",
+    size = "small",
+    srcs = ["csv_dataset_serialization_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":dataset_serialization_test_base",
+        "//tensorflow/contrib/data/python/ops:readers",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_ops",
+    ],
+)
+
 py_test(
     name = "dataset_constructor_serialization_test",
     size = "medium",

@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CsvDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.platform import test
+
+
+class CsvDatasetSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def setUp(self):
+    self._num_cols = 7
+    self._num_rows = 10
+    self._num_epochs = 14
+    self._num_outputs = self._num_rows * self._num_epochs
+
+    inputs = [
+        ",".join(str(self._num_cols * j + i)
+                 for i in range(self._num_cols))
+        for j in range(self._num_rows)
+    ]
+    contents = "\n".join(inputs).encode("utf-8")
+
+    self._filename = os.path.join(self.get_temp_dir(), "file.csv")
+    self._compressed = os.path.join(self.get_temp_dir(),
+                                    "comp.csv")  # GZip compressed
+
+    with open(self._filename, "wb") as f:
+      f.write(contents)
+    with gzip.GzipFile(self._compressed, "wb") as f:
+      f.write(contents)
+
+  def ds_func(self, **kwargs):
+    compression_type = kwargs.get("compression_type", None)
+    if compression_type == "GZIP":
+      filename = self._compressed
+    elif compression_type is None:
+      filename = self._filename
+    else:
+      raise ValueError("Invalid compression type:", compression_type)
+
+    return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
+
+  def testSerializationCore(self):
+    defs = [[0]] * self._num_cols
+    self.run_core_tests(
+        lambda: self.ds_func(record_defaults=defs, buffer_size=2),
+        lambda: self.ds_func(record_defaults=defs, buffer_size=12),
+        self._num_outputs)
+
+
+if __name__ == "__main__":
+  test.main()