Make return_state
as explicit kwarg in the Conv2DLSTM layer.
It was previously hide in the **kwargs, and we are also missing documentation for it. The existing test case should already cover the functionality of it. PiperOrigin-RevId: 317197835 Change-Id: Icfae1e177eeb886b41345078f6b93f282a94df5b
This commit is contained in:
parent
a82b75c82b
commit
0deffad6ac
@ -753,7 +753,9 @@ class ConvLSTM2D(ConvRNN2D):
|
|||||||
the `recurrent_kernel` weights matrix.
|
the `recurrent_kernel` weights matrix.
|
||||||
bias_constraint: Constraint function applied to the bias vector.
|
bias_constraint: Constraint function applied to the bias vector.
|
||||||
return_sequences: Boolean. Whether to return the last output
|
return_sequences: Boolean. Whether to return the last output
|
||||||
in the output sequence, or the full sequence.
|
in the output sequence, or the full sequence. (default False)
|
||||||
|
return_state: Boolean Whether to return the last state
|
||||||
|
in addition to the output. (default False)
|
||||||
go_backwards: Boolean (default False).
|
go_backwards: Boolean (default False).
|
||||||
If True, process the input sequence backwards.
|
If True, process the input sequence backwards.
|
||||||
stateful: Boolean (default False). If True, the last state
|
stateful: Boolean (default False). If True, the last state
|
||||||
@ -786,22 +788,27 @@ class ConvLSTM2D(ConvRNN2D):
|
|||||||
`(samples, time, rows, cols, channels)`
|
`(samples, time, rows, cols, channels)`
|
||||||
|
|
||||||
Output shape:
|
Output shape:
|
||||||
- If `return_sequences`
|
- If `return_state`: a list of tensors. The first tensor is
|
||||||
- If data_format='channels_first'
|
the output. The remaining tensors are the last states,
|
||||||
5D tensor with shape:
|
each 4D tensor with shape:
|
||||||
`(samples, time, filters, output_row, output_col)`
|
`(samples, filters, new_rows, new_cols)`
|
||||||
- If data_format='channels_last'
|
if data_format='channels_first'
|
||||||
5D tensor with shape:
|
or 4D tensor with shape:
|
||||||
`(samples, time, output_row, output_col, filters)`
|
`(samples, new_rows, new_cols, filters)`
|
||||||
- Else
|
if data_format='channels_last'.
|
||||||
- If data_format ='channels_first'
|
`rows` and `cols` values might have changed due to padding.
|
||||||
4D tensor with shape:
|
- If `return_sequences`: 5D tensor with shape:
|
||||||
`(samples, filters, output_row, output_col)`
|
`(samples, timesteps, filters, new_rows, new_cols)`
|
||||||
- If data_format='channels_last'
|
if data_format='channels_first'
|
||||||
4D tensor with shape:
|
or 5D tensor with shape:
|
||||||
`(samples, output_row, output_col, filters)`
|
`(samples, timesteps, new_rows, new_cols, filters)`
|
||||||
where `o_row` and `o_col` depend on the shape of the filter and
|
if data_format='channels_last'.
|
||||||
the padding
|
- Else, 4D tensor with shape:
|
||||||
|
`(samples, filters, new_rows, new_cols)`
|
||||||
|
if data_format='channels_first'
|
||||||
|
or 4D tensor with shape:
|
||||||
|
`(samples, new_rows, new_cols, filters)`
|
||||||
|
if data_format='channels_last'.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: in case of invalid constructor arguments.
|
ValueError: in case of invalid constructor arguments.
|
||||||
@ -834,6 +841,7 @@ class ConvLSTM2D(ConvRNN2D):
|
|||||||
recurrent_constraint=None,
|
recurrent_constraint=None,
|
||||||
bias_constraint=None,
|
bias_constraint=None,
|
||||||
return_sequences=False,
|
return_sequences=False,
|
||||||
|
return_state=False,
|
||||||
go_backwards=False,
|
go_backwards=False,
|
||||||
stateful=False,
|
stateful=False,
|
||||||
dropout=0.,
|
dropout=0.,
|
||||||
@ -863,6 +871,7 @@ class ConvLSTM2D(ConvRNN2D):
|
|||||||
dtype=kwargs.get('dtype'))
|
dtype=kwargs.get('dtype'))
|
||||||
super(ConvLSTM2D, self).__init__(cell,
|
super(ConvLSTM2D, self).__init__(cell,
|
||||||
return_sequences=return_sequences,
|
return_sequences=return_sequences,
|
||||||
|
return_state=return_state,
|
||||||
go_backwards=go_backwards,
|
go_backwards=go_backwards,
|
||||||
stateful=stateful,
|
stateful=stateful,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
@ -207,7 +207,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
|
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
@ -207,7 +207,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
|
argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
Loading…
Reference in New Issue
Block a user