Update ConvLSTMCell to accept shape as list or tuple

PiperOrigin-RevId: 228852641
This commit is contained in:
A. Unique TensorFlower 2019-01-11 02:22:26 -08:00 committed by TensorFlower Gardener
parent c2ab6dc7a1
commit bb87d0021f

View File

@ -2071,7 +2071,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
conv_ndims: Convolution dimensionality (1, 2 or 3).
input_shape: Shape of the input as int tuple, excluding the batch size.
output_channels: int, number of output channels of the conv LSTM.
kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
use_bias: (bool) Use bias in convolutions.
skip_connection: If set to `True`, concatenate the input to the
output of the conv LSTM. Default: `False`.
@ -2092,7 +2092,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
self._conv_ndims = conv_ndims
self._input_shape = input_shape
self._output_channels = output_channels
self._kernel_shape = kernel_shape
self._kernel_shape = list(kernel_shape)
self._use_bias = use_bias
self._forget_bias = forget_bias
self._skip_connection = skip_connection
@ -2172,7 +2172,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
Args:
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
batch x n, Tensors.
filter_size: int tuple of filter height and width.
filter_size: int tuple of filter shape (of size 1, 2 or 3).
num_features: int, number of features.
bias: Whether to use biases in the convolution layer.
bias_start: starting value to initialize the bias; 0 by default.