From 9b3a29889f6533b2f3c8cba6eba9e2d653712750 Mon Sep 17 00:00:00 2001
From: Rachel Lim
Date: Wed, 18 Jul 2018 07:55:00 -0700
Subject: [PATCH] [tf.data] Add checkpointing for CsvDataset

PiperOrigin-RevId: 205078174
---
 .../contrib/data/kernels/csv_dataset_op.cc        | 100 ++++++++++++++++--
 .../python/kernel_tests/serialization/BUILD       |  14 +++
 .../csv_dataset_serialization_test.py             |  73 +++++++++++++
 3 files changed, 179 insertions(+), 8 deletions(-)
 create mode 100644 tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py

diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index dadde705e1c..f7e3ed886c6 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -150,6 +150,7 @@ class CSVDatasetOp : public DatasetOpKernel {
           delim_(delim),
           na_value_(std::move(na_value)),
           use_compression_(!compression_type.empty()),
+          compression_type_(std::move(compression_type)),
           options_(options) {}
 
     std::unique_ptr<IteratorBase> MakeIteratorInternal(
@@ -169,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel {
    protected:
     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      // TODO(rachelim): Implement this
-      std::vector<Node*> input_tensors;
-      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
-      return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+      Node* filenames = nullptr;
+      Node* compression_type = nullptr;
+      Node* buffer_size = nullptr;
+      Node* header = nullptr;
+      Node* delim = nullptr;
+      Node* use_quote_delim = nullptr;
+      Node* na_value = nullptr;
+      Node* select_cols = nullptr;
+
+      std::vector<Node*> record_defaults;
+      record_defaults.reserve(record_defaults_.size());
+      for (const Tensor& t : record_defaults_) {
+        Node* node;
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        record_defaults.emplace_back(node);
+      }
+
+      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+      TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+      TF_RETURN_IF_ERROR(
+          b->AddScalar(options_.input_buffer_size, &buffer_size));
+      TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+      string delim_string(1, delim_);
+      TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+      TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this,
+          {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+           std::make_pair(2, buffer_size), std::make_pair(3, header),
+           std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+           std::make_pair(6, na_value),
+           std::make_pair(7, select_cols)},      // Single tensor inputs
+          {std::make_pair(8, record_defaults)},  // Tensor list inputs
+          {}, output));
+      return Status::OK();
     }
 
    private:
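The graph serialization above is what lets a checkpointed iterator rebuild its dataset: inputs 0 through 7 are the single-tensor constructor arguments and input 8 is the list of per-column record defaults. As a rough illustration only (not part of this patch), those indices line up with the arguments of the contrib Python wrapper; the file name, defaults, and buffer size below are made up, and argument names follow the TF 1.x contrib API:

    import tensorflow as tf

    # Each constructor argument corresponds to one of the graph inputs
    # serialized by AsGraphDefInternal (comments give the input index).
    dataset = tf.contrib.data.CsvDataset(
        ["data.csv"],                    # 0: filenames
        record_defaults=[[0.0], [0.0]],  # 8: one default per selected column
        compression_type=None,           # 1: "", "GZIP" or "ZLIB"
        buffer_size=4 * 1024 * 1024,     # 2: bytes read per FillBuffer call
        header=False,                    # 3
        field_delim=",",                 # 4
        use_quote_delim=True,            # 5
        na_value="",                     # 6
        select_cols=None)                # 7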
@@ -224,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel {
    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      // TODO(rachelim): Implement save
-      return errors::Unimplemented("CSVDataset: SaveInternal");
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                             current_file_index_));
+      // `input_stream_` is empty if
+      // 1. GetNext has not been called even once.
+      // 2. All files have been read and the iterator has been exhausted.
+      if (input_stream_ && num_buffer_reads_ > 0) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+        // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+                                               num_buffer_reads_));
+      }
+      return Status::OK();
     }
+
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      // TODO(rachelim): Implement restore
-      return errors::Unimplemented("CSVDataset: RestoreInternal");
+      ResetStreamsLocked();
+      int64 current_file_index;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                            &current_file_index));
+      current_file_index_ = size_t(current_file_index);
+      // The keys "pos" and "num_buffer_reads" are written only if
+      // the iterator was saved with an open, partially read file.
+      if (reader->Contains(full_name("pos"))) {
+        int64 pos, num_buffer_reads;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+                                              &num_buffer_reads));
+
+        TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+        num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+        // Restores the most recently held buffer
+        Status s = input_stream_->SkipNBytes(
+            num_buffer_reads_ * dataset()->options_.input_buffer_size);
+        if (!s.ok() && !errors::IsOutOfRange(s)) {
+          // We might get out of range error here if the size of the file
+          // is not an exact multiple of the buffer size, and the last buffer
+          // read is < buffer_size. This is valid and we do not surface the
+          // error.
+          return s;
+        }
+
+        Status s2 = FillBuffer(&buffer_);
+        if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+          return s2;
+        }
+        pos_ = size_t(pos);
+      }
+      return Status::OK();
     }
 
    private:
@@ -533,6 +613,7 @@ class CSVDatasetOp : public DatasetOpKernel {
 
     Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       result->clear();
+      ++num_buffer_reads_;
       Status s = input_stream_->ReadNBytes(
           dataset()->options_.input_buffer_size, result);
 
@@ -712,6 +793,7 @@ class CSVDatasetOp : public DatasetOpKernel {
       }
       buffer_.clear();
       pos_ = 0;
+      num_buffer_reads_ = 0;
       if (dataset()->header_) {
         // Read one line, but don't include it. Pass nullptrs as dummy
         // pointers to objects that shouldn't be invoked anyway
@@ -737,6 +819,7 @@ class CSVDatasetOp : public DatasetOpKernel {
     string buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
     size_t pos_ GUARDED_BY(
         mu_);  // Index into the buffer must be maintained between iters
+    size_t num_buffer_reads_ GUARDED_BY(mu_);
     std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
         GUARDED_BY(mu_);
     std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
@@ -755,6 +838,7 @@ class CSVDatasetOp : public DatasetOpKernel {
   const char delim_;
   const string na_value_;
   const bool use_compression_;
+  const string compression_type_;
   const io::ZlibCompressionOptions options_;
 };  // class Dataset
 
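Taken together, the save path records the current file index and, for a partially read file, the in-buffer offset (`pos`) and the number of buffer fills so far (`num_buffer_reads`). On restore, the iterator reopens the file, skips the `num_buffer_reads - 1` buffers it had already consumed in full, lets FillBuffer re-read the buffer it was in the middle of (which re-increments the counter), and resumes at `pos`; an OutOfRange status from the skip or the refill is tolerated because the last buffer of a file is usually shorter than `buffer_size`. A minimal sketch of the offset arithmetic this reproduces (illustrative only, not code from the patch):

    def restored_offset(num_buffer_reads, pos, buffer_size):
      # (num_buffer_reads - 1) full buffers are skipped with SkipNBytes, one
      # more buffer is re-read by FillBuffer, and `pos` indexes into it.
      return (num_buffer_reads - 1) * buffer_size + pos

    # A state saved as (num_buffer_reads=3, pos=5) with buffer_size=10 resumes
    # at byte 25 of the (possibly decompressed) stream.
    assert restored_offset(num_buffer_reads=3, pos=5, buffer_size=10) == 25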
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 686788522ac..3c3f23f9a98 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -72,6 +72,20 @@ py_test(
     ],
 )
 
+py_test(
+    name = "csv_dataset_serialization_test",
+    size = "small",
+    srcs = ["csv_dataset_serialization_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":dataset_serialization_test_base",
+        "//tensorflow/contrib/data/python/ops:readers",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_ops",
+    ],
+)
+
 py_test(
     name = "dataset_constructor_serialization_test",
     size = "medium",
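The new target plugs CsvDataset into the shared serialization suite, whose DatasetSerializationTestBase drives repeated save/restore cycles through tf.train.Saver. Outside the test harness, the same round trip looks roughly like the sketch below (TF 1.x contrib API; the file name, column count, and checkpoint path are made up):

    import tensorflow as tf

    dataset = tf.contrib.data.CsvDataset(
        ["data.csv"], record_defaults=[[0]] * 7, buffer_size=2).repeat(14)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Register the iterator's state so that tf.train.Saver checkpoints it;
    # this is what ends up calling SaveInternal / RestoreInternal above.
    saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      for _ in range(5):
        sess.run(next_element)            # consume part of the data
      path = saver.save(sess, "/tmp/csv_ckpt")

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      saver.restore(sess, path)           # rewinds to the saved position
      print(sess.run(next_element))       # continues where the save left off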
+# ============================================================================== +"""Tests for the CsvDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.platform import test + + +class CsvDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._num_cols = 7 + self._num_rows = 10 + self._num_epochs = 14 + self._num_outputs = self._num_rows * self._num_epochs + + inputs = [ + ",".join(str(self._num_cols * j + i) + for i in range(self._num_cols)) + for j in range(self._num_rows) + ] + contents = "\n".join(inputs).encode("utf-8") + + self._filename = os.path.join(self.get_temp_dir(), "file.csv") + self._compressed = os.path.join(self.get_temp_dir(), + "comp.csv") # GZip compressed + + with open(self._filename, "wb") as f: + f.write(contents) + with gzip.GzipFile(self._compressed, "wb") as f: + f.write(contents) + + def ds_func(self, **kwargs): + compression_type = kwargs.get("compression_type", None) + if compression_type == "GZIP": + filename = self._compressed + elif compression_type is None: + filename = self._filename + else: + raise ValueError("Invalid compression type:", compression_type) + + return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs) + + def testSerializationCore(self): + defs = [[0]] * self._num_cols + self.run_core_tests( + lambda: self.ds_func(record_defaults=defs, buffer_size=2), + lambda: self.ds_func(record_defaults=defs, buffer_size=12), + self._num_outputs) + + +if __name__ == "__main__": + test.main()