Improved seq2seq documentation.
Improved several docstrings. Added an example in the guide. Change: 153592069
This commit is contained in:
parent
dbdfa33723
commit
ba7f9a78c8
tensorflow
contrib/seq2seq/python/ops
docs_src/api_guides/python
@ -52,14 +52,13 @@ class BasicDecoder(decoder.Decoder):
|
||||
cell: An `RNNCell` instance.
|
||||
helper: A `Helper` instance.
|
||||
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
|
||||
The initial state of the RNNCell.
|
||||
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
|
||||
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
|
||||
to storing the result or sampling.
|
||||
|
||||
Raises:
|
||||
TypeError: if `cell` is not an instance of `RNNCell`, `helper`
|
||||
is not an instance of `Helper`, or `output_layer` is not an instance
|
||||
of `tf.layers.Layer`.
|
||||
TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
|
||||
"""
|
||||
if not isinstance(cell, core_rnn_cell.RNNCell):
|
||||
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
|
||||
|
@ -44,11 +44,22 @@ _transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-a
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Decoder(object):
|
||||
"""An RNN Decoder abstract interface object."""
|
||||
"""An RNN Decoder abstract interface object.
|
||||
|
||||
Concepts used by this interface:
|
||||
- `inputs`: (structure of) tensors and TensorArrays that is passed as input to
|
||||
the RNNCell composing the decoder, at each time step.
|
||||
- `state`: (structure of) tensors and TensorArrays that is passed to the
|
||||
RNNCell instance as the state.
|
||||
- `finished`: boolean tensor telling whether each sequence in the batch is
|
||||
finished.
|
||||
- `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each
|
||||
time step.
|
||||
"""
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
"""The batch size of the inputs returned by `sample`."""
|
||||
"""The batch size of input values."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@ -65,11 +76,14 @@ class Decoder(object):
|
||||
def initialize(self, name=None):
|
||||
"""Called before any decoding iterations.
|
||||
|
||||
This methods must compute initial input values and initial state.
|
||||
|
||||
Args:
|
||||
name: Name scope for any created operations.
|
||||
|
||||
Returns:
|
||||
`(finished, first_inputs, initial_state)`.
|
||||
`(finished, initial_inputs, initial_state)`: initial values of
|
||||
'finished' flags, inputs and state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -78,13 +92,19 @@ class Decoder(object):
|
||||
"""Called per step of decoding (but only once for dynamic decoding).
|
||||
|
||||
Args:
|
||||
time: Scalar `int32` tensor.
|
||||
inputs: Input (possibly nested tuple of) tensor[s] for this time step.
|
||||
state: State (possibly nested tuple of) tensor[s] from previous time step.
|
||||
time: Scalar `int32` tensor. Current step number.
|
||||
inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
|
||||
step.
|
||||
state: RNNCell state (possibly nested tuple of) tensor[s] from previous
|
||||
time step.
|
||||
name: Name scope for any created operations.
|
||||
|
||||
Returns:
|
||||
`(outputs, next_state, next_inputs, finished)`.
|
||||
`(outputs, next_state, next_inputs, finished)`: `outputs` is an instance
|
||||
of BasicDecoderOutput, `next_state` is a (structure of) state tensors and
|
||||
TensorArrays, `next_inputs` is the tensor that should be used as input for
|
||||
the next step, `finished` is a boolean tensor telling whether the sequence
|
||||
is complete, for each sequence in the batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -114,6 +134,8 @@ def dynamic_decode(decoder,
|
||||
scope=None):
|
||||
"""Perform dynamic decoding with `decoder`.
|
||||
|
||||
Calls initialize() once and step() repeatedly on the Decoder object.
|
||||
|
||||
Args:
|
||||
decoder: A `Decoder` instance.
|
||||
output_time_major: Python boolean. Default: `False` (batch major). If
|
||||
|
@ -57,11 +57,17 @@ def _unstack_ta(inp):
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Helper(object):
|
||||
"""Helper interface. Helper instances are used by SamplingDecoder."""
|
||||
"""Interface for implementing sampling in seq2seq decoders.
|
||||
|
||||
Helper instances are used by `BasicDecoder`.
|
||||
"""
|
||||
|
||||
@abc.abstractproperty
|
||||
def batch_size(self):
|
||||
"""Returns a scalar int32 tensor."""
|
||||
"""Batch size of tensor returned by `sample`.
|
||||
|
||||
Returns a scalar int32 tensor.
|
||||
"""
|
||||
raise NotImplementedError("batch_size has not been implemented")
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -91,6 +91,32 @@ not a suggested device partitioning strategy.)
|
||||
|
||||
## Dynamic Decoding
|
||||
|
||||
Example usage:
|
||||
|
||||
``` python
|
||||
cell = # instance of RNNCell
|
||||
|
||||
if mode == "train":
|
||||
helper = tf.contrib.seq2seq.TrainingHelper(
|
||||
input=input_vectors,
|
||||
sequence_length=input_lengths)
|
||||
elif mode == "infer":
|
||||
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
|
||||
embedding=embedding,
|
||||
start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
|
||||
end_token=END_SYMBOL)
|
||||
|
||||
decoder = tf.contrib.seq2seq.BasicDecoder(
|
||||
cell=cell,
|
||||
helper=helper,
|
||||
initial_state=cell.zero_state(batch_size, tf.float32))
|
||||
outputs, _ = tf.contrib.seq2seq.dynamic_decode(
|
||||
decoder=decoder,
|
||||
output_time_major=False,
|
||||
impute_finished=True,
|
||||
maximum_iterations=20)
|
||||
```
|
||||
|
||||
### Decoder base class and functions
|
||||
* @{tf.contrib.seq2seq.Decoder}
|
||||
* @{tf.contrib.seq2seq.dynamic_decode}
|
||||
|
Loading…
Reference in New Issue
Block a user