diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index d19e2b0d5e4..e73a637027c 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -52,14 +52,13 @@ class BasicDecoder(decoder.Decoder): cell: An `RNNCell` instance. helper: A `Helper` instance. initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + The initial state of the RNNCell. output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. Raises: - TypeError: if `cell` is not an instance of `RNNCell`, `helper` - is not an instance of `Helper`, or `output_layer` is not an instance - of `tf.layers.Layer`. + TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. """ if not isinstance(cell, core_rnn_cell.RNNCell): raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index 6338eb152e9..94c92579431 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -44,11 +44,22 @@ _transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-a @six.add_metaclass(abc.ABCMeta) class Decoder(object): - """An RNN Decoder abstract interface object.""" + """An RNN Decoder abstract interface object. + + Concepts used by this interface: + - `inputs`: (structure of) tensors and TensorArrays that is passed as input to + the RNNCell composing the decoder, at each time step. + - `state`: (structure of) tensors and TensorArrays that is passed to the + RNNCell instance as the state. + - `finished`: boolean tensor telling whether each sequence in the batch is + finished. + - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each + time step. + """ @property def batch_size(self): - """The batch size of the inputs returned by `sample`.""" + """The batch size of input values.""" raise NotImplementedError @property @@ -65,11 +76,14 @@ class Decoder(object): def initialize(self, name=None): """Called before any decoding iterations. + This methods must compute initial input values and initial state. + Args: name: Name scope for any created operations. Returns: - `(finished, first_inputs, initial_state)`. + `(finished, initial_inputs, initial_state)`: initial values of + 'finished' flags, inputs and state. """ raise NotImplementedError @@ -78,13 +92,19 @@ class Decoder(object): """Called per step of decoding (but only once for dynamic decoding). Args: - time: Scalar `int32` tensor. - inputs: Input (possibly nested tuple of) tensor[s] for this time step. - state: State (possibly nested tuple of) tensor[s] from previous time step. + time: Scalar `int32` tensor. Current step number. + inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time + step. + state: RNNCell state (possibly nested tuple of) tensor[s] from previous + time step. name: Name scope for any created operations. Returns: - `(outputs, next_state, next_inputs, finished)`. + `(outputs, next_state, next_inputs, finished)`: `outputs` is an instance + of BasicDecoderOutput, `next_state` is a (structure of) state tensors and + TensorArrays, `next_inputs` is the tensor that should be used as input for + the next step, `finished` is a boolean tensor telling whether the sequence + is complete, for each sequence in the batch. """ raise NotImplementedError @@ -114,6 +134,8 @@ def dynamic_decode(decoder, scope=None): """Perform dynamic decoding with `decoder`. + Calls initialize() once and step() repeatedly on the Decoder object. + Args: decoder: A `Decoder` instance. output_time_major: Python boolean. Default: `False` (batch major). If diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index e43d155608a..e2d56063a29 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -57,11 +57,17 @@ def _unstack_ta(inp): @six.add_metaclass(abc.ABCMeta) class Helper(object): - """Helper interface. Helper instances are used by SamplingDecoder.""" + """Interface for implementing sampling in seq2seq decoders. + + Helper instances are used by `BasicDecoder`. + """ @abc.abstractproperty def batch_size(self): - """Returns a scalar int32 tensor.""" + """Batch size of tensor returned by `sample`. + + Returns a scalar int32 tensor. + """ raise NotImplementedError("batch_size has not been implemented") @abc.abstractmethod diff --git a/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md b/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md index 223bf4a0a31..2522e50c266 100644 --- a/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md +++ b/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md @@ -91,6 +91,32 @@ not a suggested device partitioning strategy.) ## Dynamic Decoding +Example usage: + +``` python +cell = # instance of RNNCell + +if mode == "train": + helper = tf.contrib.seq2seq.TrainingHelper( + input=input_vectors, + sequence_length=input_lengths) +elif mode == "infer": + helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( + embedding=embedding, + start_tokens=tf.tile([GO_SYMBOL], [batch_size]), + end_token=END_SYMBOL) + +decoder = tf.contrib.seq2seq.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state(batch_size, tf.float32)) +outputs, _ = tf.contrib.seq2seq.dynamic_decode( + decoder=decoder, + output_time_major=False, + impute_finished=True, + maximum_iterations=20) +``` + ### Decoder base class and functions * @{tf.contrib.seq2seq.Decoder} * @{tf.contrib.seq2seq.dynamic_decode}