From 8569d12747c13ac16e864db1482b6754904371c8 Mon Sep 17 00:00:00 2001
From: Andrew Audibert
Date: Mon, 30 Sep 2019 10:14:25 -0700
Subject: [PATCH] Update docs for tf.data.experimental.TFRecordWriter

PiperOrigin-RevId: 272008952
---
 .../python/data/experimental/ops/writers.py | 40 ++++++++++++++++---
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
index 0d1785c7ee3..9c88e2043d7 100644
--- a/tensorflow/python/data/experimental/ops/writers.py
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -28,16 +28,25 @@ from tensorflow.python.util.tf_export import tf_export
 
 @tf_export("data.experimental.TFRecordWriter")
 class TFRecordWriter(object):
-  """Writes data to a TFRecord file.
+  """Writes a dataset to a TFRecord file.
 
-  To write a `dataset` to a single TFRecord file:
+  The elements of the dataset must be scalar strings. To serialize dataset
+  elements as strings, you can use the `tf.io.serialize_tensor` function.
 
   ```python
-  dataset = ... # dataset to be written
-  writer = tf.data.experimental.TFRecordWriter(PATH)
+  dataset = tf.data.Dataset.range(3)
+  dataset = dataset.map(tf.io.serialize_tensor)
+  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
   writer.write(dataset)
   ```
 
+  To read back the elements, use `TFRecordDataset`.
+
+  ```python
+  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
+  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
+  ```
+
   To shard a `dataset` across multiple TFRecord files:
 
   ```python
@@ -57,6 +66,14 @@ class TFRecordWriter(object):
   """
 
   def __init__(self, filename, compression_type=None):
+    """Initializes a `TFRecordWriter`.
+
+    Args:
+      filename: a string path indicating where to write the TFRecord data.
+      compression_type: (Optional.) a string indicating what type of compression
+        to use when writing the file. See `tf.io.TFRecordCompressionType` for
+        what types of compression are available. Defaults to `None`.
+    """
     self._filename = ops.convert_to_tensor(
         filename, dtypes.string, name="filename")
     self._compression_type = convert.optional_param_to_tensor(
@@ -66,13 +83,24 @@
         argument_dtype=dtypes.string)
 
   def write(self, dataset):
-    """Returns a `tf.Operation` to write a dataset to a file.
+    """Writes a dataset to a TFRecord file.
+
+    This method writes the content of the specified dataset to the file
+    specified in the constructor.
+
+    If the file exists, it will be overwritten.
 
     Args:
       dataset: a `tf.data.Dataset` whose elements are to be written to a file
 
     Returns:
-      A `tf.Operation` that, when run, writes contents of `dataset` to a file.
+      In graph mode, this returns an operation which, when executed, performs
+      the write. In eager mode, the write is performed by the method itself
+      and there is no return value.
+
+    Raises:
+      TypeError: if `dataset` is not a `tf.data.Dataset`.
+      TypeError: if the elements produced by the dataset are not scalar strings.
     """
     if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
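
For reference, a minimal end-to-end sketch of the usage pattern the updated docstring describes, assuming TF 2.x eager execution; the `/tmp/example.tfrecord` path and the GZIP compression choice are illustrative, not part of the patch:

```python
import tensorflow as tf

# Serialize a small dataset of int64 scalars to strings, as the docstring
# requires, then write them to a compressed TFRecord file.
dataset = tf.data.Dataset.range(3)
dataset = dataset.map(tf.io.serialize_tensor)

writer = tf.data.experimental.TFRecordWriter(
    "/tmp/example.tfrecord", compression_type="GZIP")
writer.write(dataset)  # In eager mode the write happens immediately.

# Read the records back; the compression type must match what was written.
restored = tf.data.TFRecordDataset(
    "/tmp/example.tfrecord", compression_type="GZIP")
restored = restored.map(lambda x: tf.io.parse_tensor(x, tf.int64))

for value in restored:
  print(value.numpy())  # 0, 1, 2
```

Note that `parse_tensor` must be given the same dtype that was serialized (here `tf.int64` from `Dataset.range`).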
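
The updated `write` docstring distinguishes graph mode from eager mode. A sketch of the graph-mode path, under the assumption of a TF 1.x-style session via `tf.compat.v1`; here the returned operation must be run explicitly:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

dataset = tf.data.Dataset.range(3).map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter("/tmp/example.tfrecord")

# In graph mode, write() returns a tf.Operation; nothing is written until
# the operation is executed in a session.
write_op = writer.write(dataset)

with tf.compat.v1.Session() as sess:
  sess.run(write_op)
```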