Make return_state an explicit kwarg in the ConvLSTM2D layer.

It was previously hidden in the **kwargs, and documentation for it was also missing.

The existing test cases should already cover its functionality.

PiperOrigin-RevId: 317197835
Change-Id: Icfae1e177eeb886b41345078f6b93f282a94df5b
This commit is contained in:
Scott Zhu 2020-06-18 15:50:38 -07:00 committed by TensorFlower Gardener
parent a82b75c82b
commit 0deffad6ac
3 changed files with 28 additions and 19 deletions

View File

@ -753,7 +753,9 @@ class ConvLSTM2D(ConvRNN2D):
the `recurrent_kernel` weights matrix. the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector. bias_constraint: Constraint function applied to the bias vector.
return_sequences: Boolean. Whether to return the last output return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. in the output sequence, or the full sequence. (default False)
return_state: Boolean. Whether to return the last state
in addition to the output. (default False)
go_backwards: Boolean (default False). go_backwards: Boolean (default False).
If True, process the input sequence backwards. If True, process the input sequence backwards.
stateful: Boolean (default False). If True, the last state stateful: Boolean (default False). If True, the last state
@ -786,22 +788,27 @@ class ConvLSTM2D(ConvRNN2D):
`(samples, time, rows, cols, channels)` `(samples, time, rows, cols, channels)`
Output shape: Output shape:
- If `return_sequences` - If `return_state`: a list of tensors. The first tensor is
- If data_format='channels_first' the output. The remaining tensors are the last states,
5D tensor with shape: each 4D tensor with shape:
`(samples, time, filters, output_row, output_col)` `(samples, filters, new_rows, new_cols)`
- If data_format='channels_last' if data_format='channels_first'
5D tensor with shape: or 4D tensor with shape:
`(samples, time, output_row, output_col, filters)` `(samples, new_rows, new_cols, filters)`
- Else if data_format='channels_last'.
- If data_format ='channels_first' `rows` and `cols` values might have changed due to padding.
4D tensor with shape: - If `return_sequences`: 5D tensor with shape:
`(samples, filters, output_row, output_col)` `(samples, timesteps, filters, new_rows, new_cols)`
- If data_format='channels_last' if data_format='channels_first'
4D tensor with shape: or 5D tensor with shape:
`(samples, output_row, output_col, filters)` `(samples, timesteps, new_rows, new_cols, filters)`
where `o_row` and `o_col` depend on the shape of the filter and if data_format='channels_last'.
the padding - Else, 4D tensor with shape:
`(samples, filters, new_rows, new_cols)`
if data_format='channels_first'
or 4D tensor with shape:
`(samples, new_rows, new_cols, filters)`
if data_format='channels_last'.
Raises: Raises:
ValueError: in case of invalid constructor arguments. ValueError: in case of invalid constructor arguments.
@ -834,6 +841,7 @@ class ConvLSTM2D(ConvRNN2D):
recurrent_constraint=None, recurrent_constraint=None,
bias_constraint=None, bias_constraint=None,
return_sequences=False, return_sequences=False,
return_state=False,
go_backwards=False, go_backwards=False,
stateful=False, stateful=False,
dropout=0., dropout=0.,
@ -863,6 +871,7 @@ class ConvLSTM2D(ConvRNN2D):
dtype=kwargs.get('dtype')) dtype=kwargs.get('dtype'))
super(ConvLSTM2D, self).__init__(cell, super(ConvLSTM2D, self).__init__(cell,
return_sequences=return_sequences, return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards, go_backwards=go_backwards,
stateful=stateful, stateful=stateful,
**kwargs) **kwargs)

View File

@ -207,7 +207,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -207,7 +207,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"