116 lines
4.4 KiB
Python
116 lines
4.4 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Python wrappers for tf.data writers."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.data.ops import dataset_ops
|
|
from tensorflow.python.data.util import convert
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_spec
|
|
from tensorflow.python.ops import gen_experimental_dataset_ops
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@tf_export("data.experimental.TFRecordWriter")
class TFRecordWriter(object):
  """Writes the elements of a dataset to a TFRecord file.

  Every element of the dataset has to be a scalar string. Elements of other
  types can be turned into strings first, for example with
  `tf.io.serialize_tensor`.

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.map(tf.io.serialize_tensor)
  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
  writer.write(dataset)
  ```

  The written elements can be read back with `TFRecordDataset`.

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
  ```

  A `dataset` can be sharded across several TFRecord files as follows:

  ```python
  dataset = ... # dataset to be written

  def reduce_func(key, dataset):
    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
  ))
  ```
  """

  def __init__(self, filename, compression_type=None):
    """Creates a `TFRecordWriter` for the given output file.

    Both arguments are converted to string tensors here and stored on the
    instance; `write` later passes them to the underlying op.

    Args:
      filename: a string path naming the file the TFRecord data is written to.
      compression_type: (Optional.) a string selecting the compression applied
        to the written file. `tf.io.TFRecordCompressionType` lists the
        available types. Defaults to `None`, in which case the empty string
        is used as the tensor value.
    """
    filename_tensor = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    compression_tensor = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._filename = filename_tensor
    self._compression_type = compression_tensor

  def write(self, dataset):
    """Writes the contents of `dataset` to the file given at construction.

    An existing file at that location is overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to the
        file.

    Returns:
      In graph mode, an operation that performs the write when it is executed.
      In eager mode, the write is carried out by this method itself and there
      is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

    # Only datasets whose element structure is a scalar string are accepted.
    element_structure = dataset_ops.get_structure(dataset)
    scalar_string_spec = tensor_spec.TensorSpec([], dtypes.string)
    if not element_structure.is_compatible_with(scalar_string_spec):
      actual_shapes = dataset_ops.get_legacy_output_shapes(dataset)
      actual_types = dataset_ops.get_legacy_output_types(dataset)
      message = (
          "`dataset` must produce scalar `DT_STRING` tensors whereas it "
          "produces shape {0} and types {1}")
      raise TypeError(message.format(actual_shapes, actual_types))

    variant_tensor = dataset._variant_tensor  # pylint: disable=protected-access
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        variant_tensor, self._filename, self._compression_type)
|