Add doctests to Normalization, TextVectorization, and Discretization layers.
PiperOrigin-RevId: 313217146
Change-Id: I463399f0cf792f25b82168263e24463c96328e2c
parent 74e98c29aa
commit e8786b80d7
@@ -52,6 +52,16 @@ class Discretization(Layer):
       exclude the right boundary, so `bins=[0., 1., 2.]` generates bins
       `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`.
     output_mode: One of 'int', 'binary'. Defaults to 'int'.
+
+  Examples:
+
+  Bucketize float values based on provided buckets.
+  >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
+  >>> layer = Discretization(bins=[0., 1., 2.])
+  >>> layer(input)
+  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
+  array([[0, 2, 3, 1],
+         [1, 3, 2, 1]], dtype=int32)>
   """

   def __init__(self, bins, output_mode=INTEGER, **kwargs):
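The doctest output can be sanity-checked outside Keras: the left-closed bin convention the docstring describes matches NumPy's `np.digitize` with its default `right=False`. A minimal sketch (plain NumPy, not part of this commit):

```
import numpy as np

# bins=[0., 1., 2.] yields (-inf, 0.), [0., 1.), [1., 2.), [2., +inf);
# np.digitize assigns the same left-closed bucket indices.
x = np.array([[-1.5, 1.0, 3.4, 0.5], [0.0, 3.0, 1.3, 0.0]])
print(np.digitize(x, bins=[0., 1., 2.]))
# [[0 2 3 1]
#  [1 3 2 1]]  -- the same buckets as the Discretization doctest above
```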
@@ -55,6 +55,21 @@ class Normalization(CombinerPreprocessingLayer):
       in the specified axis. If set to 'None', the layer will perform scalar
       normalization (dividing the input by a single scalar value). 0 (the batch
       axis) is not allowed.
+
+
+  Examples:
+
+  Calculate the mean and variance by analyzing the dataset in `adapt`.
+
+  >>> adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32)
+  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
+  >>> layer = Normalization()
+  >>> layer.adapt(adapt_data)
+  >>> layer(input_data)
+  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
+  array([[-1.4142135 ],
+         [-0.70710677],
+         [ 0.        ]], dtype=float32)>
   """

   def __init__(self, axis=-1, dtype=None, **kwargs):
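The expected tensor values follow directly from the standard-score formula: `adapt` computes the mean (3.0) and variance (2.0) of `adapt_data`, and the layer then outputs `(x - mean) / sqrt(var)`. A quick NumPy check (not part of this commit):

```
import numpy as np

adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32)
input_data = np.array([[1.], [2.], [3.]], dtype=np.float32)

# mean = 3.0, var = 2.0, so (1 - 3) / sqrt(2) = -1.4142135, etc.
mean, var = adapt_data.mean(), adapt_data.var()
print((input_data - mean) / np.sqrt(var))
# [[-1.4142135 ]
#  [-0.70710677]
#  [ 0.        ]]
```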
@@ -146,6 +146,7 @@ class NormalizationTest(keras_parameterized.TestCase,
     self.validate_accumulator_extract(combiner, data, expected)
     self.validate_accumulator_extract_and_restore(combiner, data,
                                                   expected)
+
   @parameterized.named_parameters(
       {
           "data": np.array([[1], [2], [3], [4], [5]]),
@@ -157,42 +157,43 @@ class TextVectorization(CombinerPreprocessingLayer):
   Example:
   This example instantiates a TextVectorization layer that lowercases text,
   splits on whitespace, strips punctuation, and outputs integer vocab indices.
-  ```
-  max_features = 5000  # Maximum vocab size.
-  max_len = 40  # Sequence length to pad the outputs to.
-
-  # Create the layer.
-  vectorize_layer = text_vectorization.TextVectorization(
-    max_tokens=max_features,
-    output_mode='int',
-    output_sequence_length=max_len)
-
-  # Now that the vocab layer has been created, call `adapt` on the text-only
-  # dataset to create the vocabulary. You don't have to batch, but for large
-  # datasets this means we're not keeping spare copies of the dataset in memory.
-  vectorize_layer.adapt(text_dataset.batch(64))
-
-  # Create the model that uses the vectorize text layer
-  model = tf.keras.models.Sequential()
-
-  # Start by creating an explicit input layer. It needs to have a shape of (1,)
-  # (because we need to guarantee that there is exactly one string input per
-  # batch), and the dtype needs to be 'string'.
-  model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
-
-  # The first layer in our model is the vectorization layer. After this layer,
-  # we have a tensor of shape (batch_size, max_len) containing vocab indices.
-  model.add(vectorize_layer)
-
-  # Next, we add a layer to map those vocab indices into a space of
-  # dimensionality 'embedding_dims'. Note that we're using max_features+1 here,
-  # since there's an OOV token that gets added to the vocabulary in
-  # vectorize_layer.
-  model.add(tf.keras.layers.Embedding(max_features+1, embedding_dims))
-
-  # At this point, you have embedded float data representing your tokens, and
-  # can add whatever other layers you need to create your model.
-  ```
+  >>> text_dataset = tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"])
+  >>> max_features = 5000  # Maximum vocab size.
+  >>> max_len = 4  # Sequence length to pad the outputs to.
+  >>> embedding_dims = 2
+  >>>
+  >>> # Create the layer.
+  >>> vectorize_layer = TextVectorization(
+  ...     max_tokens=max_features,
+  ...     output_mode='int',
+  ...     output_sequence_length=max_len)
+  >>>
+  >>> # Now that the vocab layer has been created, call `adapt` on the text-only
+  >>> # dataset to create the vocabulary. You don't have to batch, but for large
+  >>> # datasets this means we're not keeping spare copies of the dataset.
+  >>> vectorize_layer.adapt(text_dataset.batch(64))
+  >>>
+  >>> # Create the model that uses the vectorize text layer
+  >>> model = tf.keras.models.Sequential()
+  >>>
+  >>> # Start by creating an explicit input layer. It needs to have a shape of
+  >>> # (1,) (because we need to guarantee that there is exactly one string
+  >>> # input per batch), and the dtype needs to be 'string'.
+  >>> model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
+  >>>
+  >>> # The first layer in our model is the vectorization layer. After this
+  >>> # layer, we have a tensor of shape (batch_size, max_len) containing vocab
+  >>> # indices.
+  >>> model.add(vectorize_layer)
+  >>>
+  >>> # Now, the model can map strings to integers, and you can add an embedding
+  >>> # layer to map these integers to learned embeddings.
+  >>> input_data = [["foo qux bar"], ["qux baz"]]
+  >>> model.predict(input_data)
+  array([[2, 1, 4, 0],
+         [1, 3, 0, 0]])
   """
   # TODO(momernick): Add an examples section to the docstring.

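The indices in the doctest output come from the adapted vocabulary plus two reserved slots: 0 for padding and 1 for out-of-vocabulary tokens, which is why the unseen word "qux" maps to 1 and the shorter second sequence is padded with 0s. A standalone sketch to inspect this (assumes TF 2.x, where the layer lives under `tf.keras.layers.experimental.preprocessing`; tie-breaking among equal-frequency tokens may vary by version):

```
import tensorflow as tf

text_dataset = tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"])
vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=5000, output_mode='int', output_sequence_length=4)
vectorize_layer.adapt(text_dataset.batch(64))

# The doctest output [[2, 1, 4, 0], [1, 3, 0, 0]] implies foo=2, baz=3, bar=4.
print(vectorize_layer.get_vocabulary())
```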