From 0deffad6acbc2f5848022bf8ae360c9adbdf1ef8 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 18 Jun 2020 15:50:38 -0700 Subject: [PATCH] Make `return_state` as explicit kwarg in the Conv2DLSTM layer. It was previously hide in the **kwargs, and we are also missing documentation for it. The existing test case should already cover the functionality of it. PiperOrigin-RevId: 317197835 Change-Id: Icfae1e177eeb886b41345078f6b93f282a94df5b --- .../keras/layers/convolutional_recurrent.py | 43 +++++++++++-------- ...orflow.keras.layers.-conv-l-s-t-m2-d.pbtxt | 2 +- ...orflow.keras.layers.-conv-l-s-t-m2-d.pbtxt | 2 +- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index 19831429b73..6c812204cba 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -753,7 +753,9 @@ class ConvLSTM2D(ConvRNN2D): the `recurrent_kernel` weights matrix. bias_constraint: Constraint function applied to the bias vector. return_sequences: Boolean. Whether to return the last output - in the output sequence, or the full sequence. + in the output sequence, or the full sequence. (default False) + return_state: Boolean Whether to return the last state + in addition to the output. (default False) go_backwards: Boolean (default False). If True, process the input sequence backwards. stateful: Boolean (default False). If True, the last state @@ -786,22 +788,27 @@ class ConvLSTM2D(ConvRNN2D): `(samples, time, rows, cols, channels)` Output shape: - - If `return_sequences` - - If data_format='channels_first' - 5D tensor with shape: - `(samples, time, filters, output_row, output_col)` - - If data_format='channels_last' - 5D tensor with shape: - `(samples, time, output_row, output_col, filters)` - - Else - - If data_format ='channels_first' - 4D tensor with shape: - `(samples, filters, output_row, output_col)` - - If data_format='channels_last' - 4D tensor with shape: - `(samples, output_row, output_col, filters)` - where `o_row` and `o_col` depend on the shape of the filter and - the padding + - If `return_state`: a list of tensors. The first tensor is + the output. The remaining tensors are the last states, + each 4D tensor with shape: + `(samples, filters, new_rows, new_cols)` + if data_format='channels_first' + or 4D tensor with shape: + `(samples, new_rows, new_cols, filters)` + if data_format='channels_last'. + `rows` and `cols` values might have changed due to padding. + - If `return_sequences`: 5D tensor with shape: + `(samples, timesteps, filters, new_rows, new_cols)` + if data_format='channels_first' + or 5D tensor with shape: + `(samples, timesteps, new_rows, new_cols, filters)` + if data_format='channels_last'. + - Else, 4D tensor with shape: + `(samples, filters, new_rows, new_cols)` + if data_format='channels_first' + or 4D tensor with shape: + `(samples, new_rows, new_cols, filters)` + if data_format='channels_last'. Raises: ValueError: in case of invalid constructor arguments. @@ -834,6 +841,7 @@ class ConvLSTM2D(ConvRNN2D): recurrent_constraint=None, bias_constraint=None, return_sequences=False, + return_state=False, go_backwards=False, stateful=False, dropout=0., @@ -863,6 +871,7 @@ class ConvLSTM2D(ConvRNN2D): dtype=kwargs.get('dtype')) super(ConvLSTM2D, self).__init__(cell, return_sequences=return_sequences, + return_state=return_state, go_backwards=go_backwards, stateful=stateful, **kwargs) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index f77d613e354..958d06a0d0f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -207,7 +207,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index f77d613e354..958d06a0d0f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -207,7 +207,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " } member_method { name: "add_loss"