Add compression options to Python's TFRecordOptions
Plumb these through to RecordWriterOptions PiperOrigin-RevId: 211894734
This commit is contained in:
parent
b096c49471
commit
e001f3ad84
@ -41,7 +41,7 @@ class RecordWriterOptions {
|
||||
|
||||
// Options specific to zlib compression.
|
||||
#if !defined(IS_SLIM_BUILD)
|
||||
ZlibCompressionOptions zlib_options;
|
||||
tensorflow::io::ZlibCompressionOptions zlib_options;
|
||||
#endif // IS_SLIM_BUILD
|
||||
};
|
||||
|
||||
|
@ -30,6 +30,8 @@ namespace io {
|
||||
|
||||
PyRecordReader::PyRecordReader() {}
|
||||
|
||||
// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
|
||||
// RecordReaderOptions, if this changes the API can be updated at that time.
|
||||
PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
|
||||
const string& compression_type_string,
|
||||
TF_Status* out_status) {
|
||||
|
@ -28,7 +28,7 @@ namespace io {
|
||||
PyRecordWriter::PyRecordWriter() {}
|
||||
|
||||
PyRecordWriter* PyRecordWriter::New(const string& filename,
|
||||
const string& compression_type_string,
|
||||
const io::RecordWriterOptions& options,
|
||||
TF_Status* out_status) {
|
||||
std::unique_ptr<WritableFile> file;
|
||||
Status s = Env::Default()->NewWritableFile(filename, &file);
|
||||
@ -38,10 +38,6 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
|
||||
}
|
||||
PyRecordWriter* writer = new PyRecordWriter;
|
||||
writer->file_ = std::move(file);
|
||||
|
||||
RecordWriterOptions options =
|
||||
RecordWriterOptions::CreateRecordWriterOptions(compression_type_string);
|
||||
|
||||
writer->writer_.reset(new RecordWriter(writer->file_.get(), options));
|
||||
return writer;
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -36,10 +37,8 @@ class RecordWriter;
|
||||
// by multiple threads.
|
||||
class PyRecordWriter {
|
||||
public:
|
||||
// TODO(vrv): make this take a shared proto to configure
|
||||
// the compression options.
|
||||
static PyRecordWriter* New(const string& filename,
|
||||
const string& compression_type_string,
|
||||
const io::RecordWriterOptions& compression_options,
|
||||
TF_Status* out_status);
|
||||
~PyRecordWriter();
|
||||
|
||||
|
@ -18,6 +18,11 @@ limitations under the License.
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
|
||||
// Define int8_t explicitly instead of including "stdint.i", since "stdint.h"
|
||||
// and "stdint.i" disagree on the definition of int64_t.
|
||||
typedef signed char int8;
|
||||
%{ typedef signed char int8; %}
|
||||
|
||||
%feature("except") tensorflow::io::PyRecordWriter::New {
|
||||
// Let other threads run while we write
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
@ -26,6 +31,7 @@ limitations under the License.
|
||||
}
|
||||
|
||||
%newobject tensorflow::io::PyRecordWriter::New;
|
||||
%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
|
||||
|
||||
%feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
|
||||
// Let other threads run while we write
|
||||
@ -35,6 +41,8 @@ limitations under the License.
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/python/lib/io/py_record_writer.h"
|
||||
%}
|
||||
|
||||
@ -48,7 +56,21 @@ limitations under the License.
|
||||
%unignore tensorflow::io::PyRecordWriter::Flush;
|
||||
%unignore tensorflow::io::PyRecordWriter::Close;
|
||||
%unignore tensorflow::io::PyRecordWriter::New;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::flush_mode;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::window_bits;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::compression_level;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::compression_method;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::mem_level;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy;
|
||||
%unignore tensorflow::io::RecordWriterOptions;
|
||||
%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
|
||||
%unignore tensorflow::io::RecordWriterOptions::zlib_options;
|
||||
|
||||
%include "tensorflow/core/lib/io/record_writer.h"
|
||||
%include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
%include "tensorflow/python/lib/io/py_record_writer.h"
|
||||
|
||||
%unignoreall
|
||||
|
@ -33,8 +33,6 @@ class TFRecordCompressionType(object):
|
||||
GZIP = 2
|
||||
|
||||
|
||||
# NOTE(vrv): This will eventually be converted into a proto. to match
|
||||
# the interface used by the C++ RecordWriter.
|
||||
@tf_export("python_io.TFRecordOptions")
|
||||
class TFRecordOptions(object):
|
||||
"""Options used for manipulating TFRecord files."""
|
||||
@ -44,14 +42,105 @@ class TFRecordOptions(object):
|
||||
TFRecordCompressionType.NONE: ""
|
||||
}
|
||||
|
||||
def __init__(self, compression_type):
|
||||
def __init__(self,
|
||||
compression_type=None,
|
||||
flush_mode=None,
|
||||
input_buffer_size=None,
|
||||
output_buffer_size=None,
|
||||
window_bits=None,
|
||||
compression_level=None,
|
||||
compression_method=None,
|
||||
mem_level=None,
|
||||
compression_strategy=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Creates a `TFRecordOptions` instance.
|
||||
|
||||
Options only effect TFRecordWriter when compression_type is not `None`.
|
||||
Documentation, details, and defaults can be found in
|
||||
[`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
|
||||
and in the [zlib manual](http://www.zlib.net/manual.html).
|
||||
Leaving an option as `None` allows C++ to set a reasonable default.
|
||||
|
||||
Args:
|
||||
compression_type: `TFRecordCompressionType` or `None`.
|
||||
flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
|
||||
input_buffer_size: int or `None`.
|
||||
output_buffer_size: int or `None`.
|
||||
window_bits: int or `None`.
|
||||
compression_level: 0 to 9, or `None`.
|
||||
compression_method: compression method or `None`.
|
||||
mem_level: 1 to 9, or `None`.
|
||||
compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.
|
||||
|
||||
Returns:
|
||||
A `TFRecordOptions` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If compression_type is invalid.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
# Check compression_type is valid, but for backwards compatibility don't
|
||||
# immediately convert to a string.
|
||||
self.get_compression_type_string(compression_type)
|
||||
self.compression_type = compression_type
|
||||
self.flush_mode = flush_mode
|
||||
self.input_buffer_size = input_buffer_size
|
||||
self.output_buffer_size = output_buffer_size
|
||||
self.window_bits = window_bits
|
||||
self.compression_level = compression_level
|
||||
self.compression_method = compression_method
|
||||
self.mem_level = mem_level
|
||||
self.compression_strategy = compression_strategy
|
||||
|
||||
@classmethod
|
||||
def get_compression_type_string(cls, options):
|
||||
"""Convert various option types to a unified string.
|
||||
|
||||
Args:
|
||||
options: `TFRecordOption`, `TFRecordCompressionType`, or string.
|
||||
|
||||
Returns:
|
||||
Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).
|
||||
|
||||
Raises:
|
||||
ValueError: If compression_type is invalid.
|
||||
"""
|
||||
if not options:
|
||||
return ""
|
||||
return cls.compression_type_map[options.compression_type]
|
||||
elif isinstance(options, TFRecordOptions):
|
||||
return cls.get_compression_type_string(options.compression_type)
|
||||
elif isinstance(options, TFRecordCompressionType):
|
||||
return cls.compression_type_map[options]
|
||||
elif options in TFRecordOptions.compression_type_map:
|
||||
return cls.compression_type_map[options]
|
||||
elif options in TFRecordOptions.compression_type_map.values():
|
||||
return options
|
||||
else:
|
||||
raise ValueError('Not a valid compression_type: "{}"'.format(options))
|
||||
|
||||
def _as_record_writer_options(self):
|
||||
"""Convert to RecordWriterOptions for use with PyRecordWriter."""
|
||||
options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
|
||||
compat.as_bytes(
|
||||
self.get_compression_type_string(self.compression_type)))
|
||||
|
||||
if self.flush_mode is not None:
|
||||
options.zlib_options.flush_mode = self.flush_mode
|
||||
if self.input_buffer_size is not None:
|
||||
options.zlib_options.input_buffer_size = self.input_buffer_size
|
||||
if self.output_buffer_size is not None:
|
||||
options.zlib_options.output_buffer_size = self.output_buffer_size
|
||||
if self.window_bits is not None:
|
||||
options.zlib_options.window_bits = self.window_bits
|
||||
if self.compression_level is not None:
|
||||
options.zlib_options.compression_level = self.compression_level
|
||||
if self.compression_method is not None:
|
||||
options.zlib_options.compression_method = self.compression_method
|
||||
if self.mem_level is not None:
|
||||
options.zlib_options.mem_level = self.mem_level
|
||||
if self.compression_strategy is not None:
|
||||
options.zlib_options.compression_strategy = self.compression_strategy
|
||||
return options
|
||||
|
||||
|
||||
@tf_export("python_io.tf_record_iterator")
|
||||
@ -100,16 +189,21 @@ class TFRecordWriter(object):
|
||||
|
||||
Args:
|
||||
path: The path to the TFRecords file.
|
||||
options: (optional) A TFRecordOptions object.
|
||||
options: (optional) String specifying compression type,
|
||||
`TFRecordCompressionType`, or `TFRecordOptions` object.
|
||||
|
||||
Raises:
|
||||
IOError: If `path` cannot be opened for writing.
|
||||
ValueError: If valid compression_type can't be determined from `options`.
|
||||
"""
|
||||
compression_type = TFRecordOptions.get_compression_type_string(options)
|
||||
if not isinstance(options, TFRecordOptions):
|
||||
options = TFRecordOptions(compression_type=options)
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
# pylint: disable=protected-access
|
||||
self._writer = pywrap_tensorflow.PyRecordWriter_New(
|
||||
compat.as_bytes(path), compat.as_bytes(compression_type), status)
|
||||
compat.as_bytes(path), options._as_record_writer_options(), status)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter a `with` block."""
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import gzip
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import zlib
|
||||
|
||||
import six
|
||||
@ -131,9 +133,6 @@ class TFCompressionTestCase(test.TestCase):
|
||||
|
||||
class TFRecordWriterTest(TFCompressionTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TFRecordWriterTest, self).setUp()
|
||||
|
||||
def _AssertFilesEqual(self, a, b, equal):
|
||||
for an, bn in zip(a, b):
|
||||
with open(an, "rb") as af, open(bn, "rb") as bf:
|
||||
@ -142,6 +141,37 @@ class TFRecordWriterTest(TFCompressionTestCase):
|
||||
else:
|
||||
self.assertNotEqual(af.read(), bf.read())
|
||||
|
||||
def _CompressionSizeDelta(self, records, options_a, options_b):
|
||||
"""Validate compression with options_a and options_b and return size delta.
|
||||
|
||||
Compress records with options_a and options_b. Uncompress both compressed
|
||||
files and assert that the contents match the original records. Finally
|
||||
calculate how much smaller the file compressed with options_a was than the
|
||||
file compressed with options_b.
|
||||
|
||||
Args:
|
||||
records: The records to compress
|
||||
options_a: First set of options to compress with, the baseline for size.
|
||||
options_b: Second set of options to compress with.
|
||||
|
||||
Returns:
|
||||
The difference in file size when using options_a vs options_b. A positive
|
||||
value means options_a was a better compression than options_b. A negative
|
||||
value means options_b had better compression than options_a.
|
||||
|
||||
"""
|
||||
|
||||
fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a)
|
||||
test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a))
|
||||
self.assertEqual(records, test_a, options_a)
|
||||
|
||||
fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b)
|
||||
test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b))
|
||||
self.assertEqual(records, test_b, options_b)
|
||||
|
||||
# Negative number => better compression.
|
||||
return os.path.getsize(fn_a) - os.path.getsize(fn_b)
|
||||
|
||||
def testWriteReadZLibFiles(self):
|
||||
# Write uncompressed then compress manually.
|
||||
options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
|
||||
@ -188,6 +218,76 @@ class TFRecordWriterTest(TFCompressionTestCase):
|
||||
]
|
||||
self._AssertFilesEqual(uncompressed_files, files, True)
|
||||
|
||||
def testNoCompressionType(self):
|
||||
self.assertEqual(
|
||||
"",
|
||||
tf_record.TFRecordOptions.get_compression_type_string(
|
||||
tf_record.TFRecordOptions()))
|
||||
|
||||
self.assertEqual(
|
||||
"",
|
||||
tf_record.TFRecordOptions.get_compression_type_string(
|
||||
tf_record.TFRecordOptions("")))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tf_record.TFRecordOptions(5)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tf_record.TFRecordOptions("BZ2")
|
||||
|
||||
def testZlibCompressionType(self):
|
||||
zlib_t = tf_record.TFRecordCompressionType.ZLIB
|
||||
|
||||
self.assertEqual(
|
||||
"ZLIB",
|
||||
tf_record.TFRecordOptions.get_compression_type_string(
|
||||
tf_record.TFRecordOptions("ZLIB")))
|
||||
|
||||
self.assertEqual(
|
||||
"ZLIB",
|
||||
tf_record.TFRecordOptions.get_compression_type_string(
|
||||
tf_record.TFRecordOptions(zlib_t)))
|
||||
|
||||
self.assertEqual(
|
||||
"ZLIB",
|
||||
tf_record.TFRecordOptions.get_compression_type_string(
|
||||
tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t))))
|
||||
|
||||
def testCompressionOptions(self):
|
||||
# Create record with mix of random and repeated data to test compression on.
|
||||
rnd = random.Random(123)
|
||||
random_record = compat.as_bytes(
|
||||
"".join(rnd.choice(string.digits) for _ in range(10000)))
|
||||
repeated_record = compat.as_bytes(_TEXT)
|
||||
for _ in range(10000):
|
||||
start_i = rnd.randint(0, len(_TEXT))
|
||||
length = rnd.randint(10, 200)
|
||||
repeated_record += _TEXT[start_i:start_i + length]
|
||||
records = [random_record, repeated_record, random_record]
|
||||
|
||||
tests = [
|
||||
("compression_level", 2, -1), # Lower compression is worse.
|
||||
("compression_level", 6, 0), # Default compression_level is equal.
|
||||
("flush_mode", zlib.Z_FULL_FLUSH, 1), # A few less bytes.
|
||||
("flush_mode", zlib.Z_NO_FLUSH, 0), # NO_FLUSH is the default.
|
||||
("input_buffer_size", 4096, 0), # Increases time not size.
|
||||
("output_buffer_size", 4096, 0), # Increases time not size.
|
||||
("window_bits", 8, -1), # Smaller than default window increases size.
|
||||
("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1), # Worse.
|
||||
("compression_strategy", zlib.Z_FILTERED, -1), # Worse.
|
||||
]
|
||||
|
||||
compression_type = tf_record.TFRecordCompressionType.ZLIB
|
||||
options_a = tf_record.TFRecordOptions(compression_type)
|
||||
for prop, value, delta_sign in tests:
|
||||
options_b = tf_record.TFRecordOptions(
|
||||
compression_type=compression_type, **{prop: value})
|
||||
delta = self._CompressionSizeDelta(records, options_a, options_b)
|
||||
self.assertTrue(
|
||||
delta == 0 if delta_sign == 0 else delta // delta_sign > 0,
|
||||
"Setting {} = {}, file was {} smaller didn't match sign of {}".format(
|
||||
prop, value, delta, delta_sign))
|
||||
|
||||
|
||||
class TFRecordWriterZlibTest(TFCompressionTestCase):
|
||||
|
||||
@ -318,6 +418,7 @@ class TFRecordIteratorTest(TFCompressionTestCase):
|
||||
for _ in tf_record.tf_record_iterator(fn_truncated):
|
||||
pass
|
||||
|
||||
|
||||
class TFRecordWriterCloseAndFlushTests(test.TestCase):
|
||||
|
||||
def setUp(self, compression_type=TFRecordCompressionType.NONE):
|
||||
|
@ -8,7 +8,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_compression_type_string"
|
||||
|
@ -8,7 +8,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_compression_type_string"
|
||||
|
Loading…
x
Reference in New Issue
Block a user