Update ConvLSTMCell to accept shape as list or tuple
PiperOrigin-RevId: 228852641
This commit is contained in:
parent
c2ab6dc7a1
commit
bb87d0021f
@ -2071,7 +2071,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
|
||||
conv_ndims: Convolution dimensionality (1, 2 or 3).
|
||||
input_shape: Shape of the input as int tuple, excluding the batch size.
|
||||
output_channels: int, number of output channels of the conv LSTM.
|
||||
kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
|
||||
kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
|
||||
use_bias: (bool) Use bias in convolutions.
|
||||
skip_connection: If set to `True`, concatenate the input to the
|
||||
output of the conv LSTM. Default: `False`.
|
||||
@ -2092,7 +2092,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
|
||||
self._conv_ndims = conv_ndims
|
||||
self._input_shape = input_shape
|
||||
self._output_channels = output_channels
|
||||
self._kernel_shape = kernel_shape
|
||||
self._kernel_shape = list(kernel_shape)
|
||||
self._use_bias = use_bias
|
||||
self._forget_bias = forget_bias
|
||||
self._skip_connection = skip_connection
|
||||
@ -2172,7 +2172,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
|
||||
Args:
|
||||
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
|
||||
batch x n, Tensors.
|
||||
filter_size: int tuple of filter height and width.
|
||||
filter_size: int tuple of filter shape (of size 1, 2 or 3).
|
||||
num_features: int, number of features.
|
||||
bias: Whether to use biases in the convolution layer.
|
||||
bias_start: starting value to initialize the bias; 0 by default.
|
||||
|
Loading…
Reference in New Issue
Block a user