Make return_state as explicit kwarg in the Conv2DLSTM layer.

It was previously hide in the **kwargs, and we are also missing documentation for it.

The existing test case should already cover the functionality of it.

PiperOrigin-RevId: 317197835
Change-Id: Icfae1e177eeb886b41345078f6b93f282a94df5b
This commit is contained in:
Scott Zhu 2020-06-18 15:50:38 -07:00 committed by TensorFlower Gardener
parent a82b75c82b
commit 0deffad6ac
3 changed files with 28 additions and 19 deletions

View File

@ -753,7 +753,9 @@ class ConvLSTM2D(ConvRNN2D):
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
in the output sequence, or the full sequence. (default False)
return_state: Boolean Whether to return the last state
in addition to the output. (default False)
go_backwards: Boolean (default False).
If True, process the input sequence backwards.
stateful: Boolean (default False). If True, the last state
@ -786,22 +788,27 @@ class ConvLSTM2D(ConvRNN2D):
`(samples, time, rows, cols, channels)`
Output shape:
- If `return_sequences`
- If data_format='channels_first'
5D tensor with shape:
`(samples, time, filters, output_row, output_col)`
- If data_format='channels_last'
5D tensor with shape:
`(samples, time, output_row, output_col, filters)`
- Else
- If data_format ='channels_first'
4D tensor with shape:
`(samples, filters, output_row, output_col)`
- If data_format='channels_last'
4D tensor with shape:
`(samples, output_row, output_col, filters)`
where `o_row` and `o_col` depend on the shape of the filter and
the padding
- If `return_state`: a list of tensors. The first tensor is
the output. The remaining tensors are the last states,
each 4D tensor with shape:
`(samples, filters, new_rows, new_cols)`
if data_format='channels_first'
or 4D tensor with shape:
`(samples, new_rows, new_cols, filters)`
if data_format='channels_last'.
`rows` and `cols` values might have changed due to padding.
- If `return_sequences`: 5D tensor with shape:
`(samples, timesteps, filters, new_rows, new_cols)`
if data_format='channels_first'
or 5D tensor with shape:
`(samples, timesteps, new_rows, new_cols, filters)`
if data_format='channels_last'.
- Else, 4D tensor with shape:
`(samples, filters, new_rows, new_cols)`
if data_format='channels_first'
or 4D tensor with shape:
`(samples, new_rows, new_cols, filters)`
if data_format='channels_last'.
Raises:
ValueError: in case of invalid constructor arguments.
@ -834,6 +841,7 @@ class ConvLSTM2D(ConvRNN2D):
recurrent_constraint=None,
bias_constraint=None,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
dropout=0.,
@ -863,6 +871,7 @@ class ConvLSTM2D(ConvRNN2D):
dtype=kwargs.get('dtype'))
super(ConvLSTM2D, self).__init__(cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
**kwargs)

View File

@ -207,7 +207,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
}
member_method {
name: "add_loss"

View File

@ -207,7 +207,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
}
member_method {
name: "add_loss"