diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index fed88004ee4..052aabf9288 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -192,8 +192,64 @@ def tf_record_iterator(path, options=None): class TFRecordWriter(object): """A class to write records to a TFRecords file. + [TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord) + + TFRecords is a binary format which is optimized for high throughput data + retrieval, generally in conjunction with `tf.data`. `TFRecordWriter` is used + to write serialized examples to a file for later consumption. The key steps + are: + + Ahead of time: + + - [Convert data into a serialized format]( + https://www.tensorflow.org/tutorials/load_data/tfrecord#tfexample) + - [Write the serialized data to one or more files]( + https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_in_python) + + During training or evaluation: + + - [Read serialized examples into memory]( + https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file) + - [Parse (deserialize) examples]( + https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file) + + A minimal example is given below: + + >>> import tempfile + >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords") + >>> np.random.seed(0) + + >>> # Write the records to a file. + ... with tf.io.TFRecordWriter(example_path) as file_writer: + ... for _ in range(4): + ... x, y = np.random.random(), np.random.random() + ... + ... record_bytes = tf.train.Example(features=tf.train.Features(feature={ + ... "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])), + ... "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])), + ... })).SerializeToString() + ... file_writer.write(record_bytes) + + >>> # Read the data back out. + >>> def decode_fn(record_bytes): + ... return tf.io.parse_single_example( + ... # Data + ... record_bytes, + ... + ... # Schema + ... {"x": tf.io.FixedLenFeature([], dtype=tf.float32), + ... "y": tf.io.FixedLenFeature([], dtype=tf.float32)} + ... ) + + >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn): + ... print("x = {x:.4f}, y = {y:.4f}".format(**batch)) + x = 0.5488, y = 0.7152 + x = 0.6028, y = 0.5449 + x = 0.4237, y = 0.6459 + x = 0.4376, y = 0.8918 + This class implements `__enter__` and `__exit__`, and can be used - in `with` blocks like a normal file. + in `with` blocks like a normal file. (See the usage example above.) """ # TODO(josh11b): Support appending?