diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 175bfbd827c..6b217444740 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -21,6 +21,12 @@ limitations under the License. namespace tensorflow { namespace io { +namespace { +bool IsZlibCompressed(RecordWriterOptions options) { + return options.compression_type == RecordWriterOptions::ZLIB_COMPRESSION; +} +} // namespace + RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( const string& compression_type) { RecordWriterOptions options; @@ -50,19 +56,20 @@ RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( RecordWriter::RecordWriter(WritableFile* dest, const RecordWriterOptions& options) : dest_(dest), options_(options) { - if (options.compression_type == RecordWriterOptions::ZLIB_COMPRESSION) { + if (IsZlibCompressed(options)) { // We don't have zlib available on all embedded platforms, so fail. #if defined(IS_SLIM_BUILD) LOG(FATAL) << "Zlib compression is unsupported on mobile platforms."; #else // IS_SLIM_BUILD - zlib_output_buffer_.reset(new ZlibOutputBuffer( - dest_, options.zlib_options.input_buffer_size, - options.zlib_options.output_buffer_size, options.zlib_options)); - Status s = zlib_output_buffer_->Init(); + ZlibOutputBuffer* zlib_output_buffer = new ZlibOutputBuffer( + dest, options.zlib_options.input_buffer_size, + options.zlib_options.output_buffer_size, options.zlib_options); + Status s = zlib_output_buffer->Init(); if (!s.ok()) { LOG(FATAL) << "Failed to initialize Zlib inputbuffer. Error: " << s.ToString(); } + dest_ = zlib_output_buffer; #endif // IS_SLIM_BUILD } else if (options.compression_type == RecordWriterOptions::NONE) { // Nothing to do @@ -73,11 +80,12 @@ RecordWriter::RecordWriter(WritableFile* dest, RecordWriter::~RecordWriter() { #if !defined(IS_SLIM_BUILD) - if (zlib_output_buffer_) { - Status s = zlib_output_buffer_->Close(); + if (IsZlibCompressed(options_)) { + Status s = dest_->Close(); if (!s.ok()) { LOG(ERROR) << "Could not finish writing file: " << s; } + delete dest_; } #endif // IS_SLIM_BUILD } @@ -99,20 +107,16 @@ Status RecordWriter::WriteRecord(StringPiece data) { char footer[sizeof(uint32)]; core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); -#if !defined(IS_SLIM_BUILD) - if (zlib_output_buffer_) { - TF_RETURN_IF_ERROR( - zlib_output_buffer_->Write(StringPiece(header, sizeof(header)))); - TF_RETURN_IF_ERROR(zlib_output_buffer_->Write(data)); - return zlib_output_buffer_->Write(StringPiece(footer, sizeof(footer))); - } else { -#endif // IS_SLIM_BUILD - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - TF_RETURN_IF_ERROR(dest_->Append(data)); - return dest_->Append(StringPiece(footer, sizeof(footer))); -#if !defined(IS_SLIM_BUILD) + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + TF_RETURN_IF_ERROR(dest_->Append(data)); + return dest_->Append(StringPiece(footer, sizeof(footer))); +} + +Status RecordWriter::Flush() { + if (IsZlibCompressed(options_)) { + return dest_->Flush(); } -#endif // IS_SLIM_BUILD + return Status::OK(); } } // namespace io diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 5a2373d7570..63f0a7c5d07 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -60,22 +60,11 @@ class RecordWriter { // Flushes any buffered data held by underlying containers of the // RecordWriter to the WritableFile. Does *not* flush the // WritableFile. - Status Flush() { -#if !defined(IS_SLIM_BUILD) - if (zlib_output_buffer_) { - return zlib_output_buffer_->Flush(); - } -#endif // IS_SLIM_BUILD - - return Status::OK(); - } + Status Flush(); private: - WritableFile* const dest_; + WritableFile* dest_; RecordWriterOptions options_; -#if !defined(IS_SLIM_BUILD) - std::unique_ptr<ZlibOutputBuffer> zlib_output_buffer_; -#endif // IS_SLIM_BUILD TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); }; diff --git a/tensorflow/core/lib/io/zlib_buffers_test.cc b/tensorflow/core/lib/io/zlib_buffers_test.cc index f4b08f42009..66ee68a9161 100644 --- a/tensorflow/core/lib/io/zlib_buffers_test.cc +++ b/tensorflow/core/lib/io/zlib_buffers_test.cc @@ -75,7 +75,7 @@ void TestAllCombinations(CompressionOptions input_options, output_options); TF_CHECK_OK(out.Init()); - TF_CHECK_OK(out.Write(StringPiece(data))); + TF_CHECK_OK(out.Append(StringPiece(data))); TF_CHECK_OK(out.Close()); TF_CHECK_OK(file_writer->Flush()); TF_CHECK_OK(file_writer->Close()); @@ -124,7 +124,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size, TF_CHECK_OK(out.Init()); for (int i = 0; i < num_writes; i++) { - TF_CHECK_OK(out.Write(StringPiece(data))); + TF_CHECK_OK(out.Append(StringPiece(data))); if (with_flush) { TF_CHECK_OK(out.Flush()); } @@ -176,7 +176,7 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) { output_options); TF_CHECK_OK(out.Init()); - TF_CHECK_OK(out.Write(StringPiece(data))); + TF_CHECK_OK(out.Append(StringPiece(data))); TF_CHECK_OK(out.Close()); TF_CHECK_OK(file_writer->Flush()); TF_CHECK_OK(file_writer->Close()); diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc index cc2d941a643..a65b36b64d4 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.cc +++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc @@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() { return Status::OK(); } -Status ZlibOutputBuffer::Write(StringPiece data) { +Status ZlibOutputBuffer::Append(const StringPiece& data) { // If there is sufficient free space in z_stream_input_ to fit data we // add it there and return. // If there isn't enough space we deflate the existing contents of @@ -197,6 +197,11 @@ Status ZlibOutputBuffer::Flush() { return Status::OK(); } +Status ZlibOutputBuffer::Sync() { + TF_RETURN_IF_ERROR(Flush()); + return file_->Sync(); +} + Status ZlibOutputBuffer::Close() { TF_RETURN_IF_ERROR(DeflateBuffered(true)); TF_RETURN_IF_ERROR(FlushOutputBufferToFile()); diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index a33472cfc53..5cad2e94570 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +34,7 @@ namespace io { // (http://www.zlib.net/). // A given instance of an ZlibOutputBuffer is NOT safe for concurrent use // by multiple threads -class ZlibOutputBuffer { +class ZlibOutputBuffer : public WritableFile { public: // Create an ZlibOutputBuffer for `file` with two buffers that cache the // 1. input data to be deflated @@ -61,10 +62,10 @@ class ZlibOutputBuffer { // to file when the buffer is full. // // To immediately write contents to file call `Flush()`. - Status Write(StringPiece data); + Status Append(const StringPiece& data) override; // Deflates any cached input and writes all output to file. - Status Flush(); + Status Flush() override; // Compresses any cached input and writes all output to file. This must be // called before the destructor to avoid any data loss. @@ -74,7 +75,10 @@ class ZlibOutputBuffer { // // After calling this, any further calls to `Write()`, `Flush()` or `Close()` // will fail. - Status Close(); + Status Close() override; + + // Deflates any cached input, writes all output to file and syncs it. + Status Sync() override; private: WritableFile* file_; // Not owned