Improved seq2seq documentation.

Improved several docstrings.
Added an example in the guide.
Change: 153592069
This commit is contained in:
A. Unique TensorFlower 2017-04-19 07:54:06 -08:00 committed by TensorFlower Gardener
parent dbdfa33723
commit ba7f9a78c8
4 changed files with 65 additions and 12 deletions
tensorflow
contrib/seq2seq/python/ops
docs_src/api_guides/python

View File

@ -52,14 +52,13 @@ class BasicDecoder(decoder.Decoder):
cell: An `RNNCell` instance.
helper: A `Helper` instance.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
The initial state of the RNNCell.
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
Raises:
TypeError: if `cell` is not an instance of `RNNCell`, `helper`
is not an instance of `Helper`, or `output_layer` is not an instance
of `tf.layers.Layer`.
TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
"""
if not isinstance(cell, core_rnn_cell.RNNCell):
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))

View File

@ -44,11 +44,22 @@ _transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-a
@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
"""An RNN Decoder abstract interface object."""
"""An RNN Decoder abstract interface object.
Concepts used by this interface:
- `inputs`: (structure of) tensors and TensorArrays that is passed as input to
the RNNCell composing the decoder, at each time step.
- `state`: (structure of) tensors and TensorArrays that is passed to the
RNNCell instance as the state.
- `finished`: boolean tensor telling whether each sequence in the batch is
finished.
- `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each
time step.
"""
@property
def batch_size(self):
"""The batch size of the inputs returned by `sample`."""
"""The batch size of input values."""
raise NotImplementedError
@property
@ -65,11 +76,14 @@ class Decoder(object):
def initialize(self, name=None):
"""Called before any decoding iterations.
This methods must compute initial input values and initial state.
Args:
name: Name scope for any created operations.
Returns:
`(finished, first_inputs, initial_state)`.
`(finished, initial_inputs, initial_state)`: initial values of
'finished' flags, inputs and state.
"""
raise NotImplementedError
@ -78,13 +92,19 @@ class Decoder(object):
"""Called per step of decoding (but only once for dynamic decoding).
Args:
time: Scalar `int32` tensor.
inputs: Input (possibly nested tuple of) tensor[s] for this time step.
state: State (possibly nested tuple of) tensor[s] from previous time step.
time: Scalar `int32` tensor. Current step number.
inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
step.
state: RNNCell state (possibly nested tuple of) tensor[s] from previous
time step.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
`(outputs, next_state, next_inputs, finished)`: `outputs` is an instance
of BasicDecoderOutput, `next_state` is a (structure of) state tensors and
TensorArrays, `next_inputs` is the tensor that should be used as input for
the next step, `finished` is a boolean tensor telling whether the sequence
is complete, for each sequence in the batch.
"""
raise NotImplementedError
@ -114,6 +134,8 @@ def dynamic_decode(decoder,
scope=None):
"""Perform dynamic decoding with `decoder`.
Calls initialize() once and step() repeatedly on the Decoder object.
Args:
decoder: A `Decoder` instance.
output_time_major: Python boolean. Default: `False` (batch major). If

View File

@ -57,11 +57,17 @@ def _unstack_ta(inp):
@six.add_metaclass(abc.ABCMeta)
class Helper(object):
"""Helper interface. Helper instances are used by SamplingDecoder."""
"""Interface for implementing sampling in seq2seq decoders.
Helper instances are used by `BasicDecoder`.
"""
@abc.abstractproperty
def batch_size(self):
"""Returns a scalar int32 tensor."""
"""Batch size of tensor returned by `sample`.
Returns a scalar int32 tensor.
"""
raise NotImplementedError("batch_size has not been implemented")
@abc.abstractmethod

View File

@ -91,6 +91,32 @@ not a suggested device partitioning strategy.)
## Dynamic Decoding
Example usage:
``` python
cell = # instance of RNNCell
if mode == "train":
helper = tf.contrib.seq2seq.TrainingHelper(
input=input_vectors,
sequence_length=input_lengths)
elif mode == "infer":
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
embedding=embedding,
start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
end_token=END_SYMBOL)
decoder = tf.contrib.seq2seq.BasicDecoder(
cell=cell,
helper=helper,
initial_state=cell.zero_state(batch_size, tf.float32))
outputs, _ = tf.contrib.seq2seq.dynamic_decode(
decoder=decoder,
output_time_major=False,
impute_finished=True,
maximum_iterations=20)
```
### Decoder base class and functions
* @{tf.contrib.seq2seq.Decoder}
* @{tf.contrib.seq2seq.dynamic_decode}