Improves documentation for keras.preprocessing.pad_sequence

PiperOrigin-RevId: 297247197
Change-Id: I4c85e8ba6d4ae43d4c249442ef9c47bb6a5805c6
This commit is contained in:
A. Unique TensorFlower 2020-02-25 17:50:37 -08:00 committed by TensorFlower Gardener
parent 42e6af1ac7
commit 82bcbf0560

View File

@ -24,7 +24,6 @@ from keras_preprocessing import sequence
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
pad_sequences = sequence.pad_sequences
make_sampling_table = sequence.make_sampling_table make_sampling_table = sequence.make_sampling_table
skipgrams = sequence.skipgrams skipgrams = sequence.skipgrams
# TODO(fchollet): consider making `_remove_long_seq` public. # TODO(fchollet): consider making `_remove_long_seq` public.
@ -34,6 +33,7 @@ _remove_long_seq = sequence._remove_long_seq # pylint: disable=protected-access
@keras_export('keras.preprocessing.sequence.TimeseriesGenerator') @keras_export('keras.preprocessing.sequence.TimeseriesGenerator')
class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence): class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence):
"""Utility class for generating batches of temporal data. """Utility class for generating batches of temporal data.
This class takes in a sequence of data-points gathered at This class takes in a sequence of data-points gathered at
equal intervals, along with time series parameters such as equal intervals, along with time series parameters such as
stride, length of history, etc., to produce batches for stride, length of history, etc., to produce batches for
@ -89,7 +89,74 @@ class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence):
pass pass
keras_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences) @keras_export('keras.preprocessing.sequence.pad_sequences')
def pad_sequences(sequences, maxlen=None, dtype='int32',
padding='pre', truncating='pre', value=0.):
"""Pads sequences to the same length.
This function transforms a list (of length `num_samples`)
of sequences (lists of integers)
into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
`num_timesteps` is either the `maxlen` argument if provided,
or the length of the longest sequence in the list.
Sequences that are shorter than `num_timesteps`
are padded with `value` until they are `num_timesteps` long.
Sequences longer than `num_timesteps` are truncated
so that they fit the desired length.
The position where padding or truncation happens is determined by
the arguments `padding` and `truncating`, respectively.
Pre-padding or removing values from the beginning of the sequence is the
default.
>>> sequence = [[1], [2, 3], [4, 5, 6]]
>>> tf.keras.preprocessing.sequence.pad_sequences(sequence)
array([[0, 0, 1],
[0, 2, 3],
[4, 5, 6]], dtype=int32)
>>> tf.keras.preprocessing.sequence.pad_sequences(sequence, value=-1)
array([[-1, -1, 1],
[-1, 2, 3],
[ 4, 5, 6]], dtype=int32)
>>> tf.keras.preprocessing.sequence.pad_sequences(sequence, padding='post')
array([[1, 0, 0],
[2, 3, 0],
[4, 5, 6]], dtype=int32)
>>> tf.keras.preprocessing.sequence.pad_sequences(sequence, maxlen=2)
array([[0, 1],
[2, 3],
[5, 6]], dtype=int32)
Arguments:
sequences: List of sequences (each sequence is a list of integers).
maxlen: Optional Int, maximum length of all sequences. If not provided,
sequences will be padded to the length of the longest individual
sequence.
dtype: (Optional, defaults to int32). Type of the output sequences.
To pad sequences with variable length strings, you can use `object`.
padding: String, 'pre' or 'post' (optional, defaults to 'pre'):
pad either before or after each sequence.
truncating: String, 'pre' or 'post' (optional, defaults to 'pre'):
remove values from sequences larger than
`maxlen`, either at the beginning or at the end of the sequences.
value: Float or String, padding value. (Optional, defaults to 0.)
Returns:
Numpy array with shape `(len(sequences), maxlen)`
Raises:
ValueError: In case of invalid values for `truncating` or `padding`,
or in case of invalid shape for a `sequences` entry.
"""
return sequence.pad_sequences(
sequences, maxlen=maxlen, dtype=dtype,
padding=padding, truncating=truncating, value=value)
keras_export( keras_export(
'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table) 'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table)
keras_export('keras.preprocessing.sequence.skipgrams')(skipgrams) keras_export('keras.preprocessing.sequence.skipgrams')(skipgrams)