Update ConvLSTMCell to accept shape as list or tuple
PiperOrigin-RevId: 228852641
This commit is contained in:
parent
c2ab6dc7a1
commit
bb87d0021f
@ -2071,7 +2071,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
|
|||||||
conv_ndims: Convolution dimensionality (1, 2 or 3).
|
conv_ndims: Convolution dimensionality (1, 2 or 3).
|
||||||
input_shape: Shape of the input as int tuple, excluding the batch size.
|
input_shape: Shape of the input as int tuple, excluding the batch size.
|
||||||
output_channels: int, number of output channels of the conv LSTM.
|
output_channels: int, number of output channels of the conv LSTM.
|
||||||
kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
|
kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
|
||||||
use_bias: (bool) Use bias in convolutions.
|
use_bias: (bool) Use bias in convolutions.
|
||||||
skip_connection: If set to `True`, concatenate the input to the
|
skip_connection: If set to `True`, concatenate the input to the
|
||||||
output of the conv LSTM. Default: `False`.
|
output of the conv LSTM. Default: `False`.
|
||||||
@ -2092,7 +2092,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
|
|||||||
self._conv_ndims = conv_ndims
|
self._conv_ndims = conv_ndims
|
||||||
self._input_shape = input_shape
|
self._input_shape = input_shape
|
||||||
self._output_channels = output_channels
|
self._output_channels = output_channels
|
||||||
self._kernel_shape = kernel_shape
|
self._kernel_shape = list(kernel_shape)
|
||||||
self._use_bias = use_bias
|
self._use_bias = use_bias
|
||||||
self._forget_bias = forget_bias
|
self._forget_bias = forget_bias
|
||||||
self._skip_connection = skip_connection
|
self._skip_connection = skip_connection
|
||||||
@ -2172,7 +2172,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
|
|||||||
Args:
|
Args:
|
||||||
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
|
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
|
||||||
batch x n, Tensors.
|
batch x n, Tensors.
|
||||||
filter_size: int tuple of filter height and width.
|
filter_size: int tuple of filter shape (of size 1, 2 or 3).
|
||||||
num_features: int, number of features.
|
num_features: int, number of features.
|
||||||
bias: Whether to use biases in the convolution layer.
|
bias: Whether to use biases in the convolution layer.
|
||||||
bias_start: starting value to initialize the bias; 0 by default.
|
bias_start: starting value to initialize the bias; 0 by default.
|
||||||
|
Loading…
Reference in New Issue
Block a user