Modified the Example code such that it is executable
Github Gist for working code is https://colab.research.google.com/gist/rmothukuru/7ee2a4dabf1743aa85c8943d6f34f0b6/gh_46128.ipynb Fixes #46128 PiperOrigin-RevId: 352429291 Change-Id: Ib392b88a58daf9c33c0cd48885f22716a6ca673f
This commit is contained in:
parent
9cb96ea8fd
commit
2cc955f533
@ -51,24 +51,35 @@ class SequenceFeatures(kfc._BaseFeaturesLayer):
|
|||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
# Behavior of some cells or feature columns may depend on whether we are in
|
# Behavior of some cells or feature columns may depend on whether we are in
|
||||||
# training or inference mode, e.g. applying dropout.
|
# training or inference mode, e.g. applying dropout.
|
||||||
training = True
|
training = True
|
||||||
rating = sequence_numeric_column('rating')
|
rating = tf.feature_column.sequence_numeric_column('rating')
|
||||||
watches = sequence_categorical_column_with_identity(
|
watches = tf.feature_column.sequence_categorical_column_with_identity(
|
||||||
'watches', num_buckets=1000)
|
'watches', num_buckets=1000)
|
||||||
watches_embedding = embedding_column(watches, dimension=10)
|
watches_embedding = tf.feature_column.embedding_column(watches,
|
||||||
|
dimension=10)
|
||||||
columns = [rating, watches_embedding]
|
columns = [rating, watches_embedding]
|
||||||
|
|
||||||
sequence_input_layer = SequenceFeatures(columns)
|
features = {
|
||||||
features = tf.io.parse_example(...,
|
'rating': tf.sparse.from_dense([[1.0,1.1, 0, 0, 0],
|
||||||
features=make_parse_example_spec(columns))
|
[2.0,2.1,2.2, 2.3, 2.5]]),
|
||||||
|
'watches': tf.sparse.from_dense([[2, 85, 0, 0, 0],[33,78, 2, 73, 1]])
|
||||||
|
}
|
||||||
|
|
||||||
|
sequence_input_layer = tf.keras.experimental.SequenceFeatures(columns)
|
||||||
sequence_input, sequence_length = sequence_input_layer(
|
sequence_input, sequence_length = sequence_input_layer(
|
||||||
features, training=training)
|
features, training=training)
|
||||||
|
|
||||||
sequence_length_mask = tf.sequence_mask(sequence_length)
|
sequence_length_mask = tf.sequence_mask(sequence_length)
|
||||||
|
|
||||||
rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size, training=training)
|
hidden_size = 32
|
||||||
rnn_layer = tf.keras.layers.RNN(rnn_cell, training=training)
|
|
||||||
|
rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
|
||||||
|
rnn_layer = tf.keras.layers.RNN(rnn_cell)
|
||||||
outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
|
outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user