Update ConvLSTMCell to accept shape as list or tuple

PiperOrigin-RevId: 228852641
This commit is contained in:
A. Unique TensorFlower 2019-01-11 02:22:26 -08:00 committed by TensorFlower Gardener
parent c2ab6dc7a1
commit bb87d0021f

View File

@ -2071,7 +2071,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
conv_ndims: Convolution dimensionality (1, 2 or 3). conv_ndims: Convolution dimensionality (1, 2 or 3).
input_shape: Shape of the input as int tuple, excluding the batch size. input_shape: Shape of the input as int tuple, excluding the batch size.
output_channels: int, number of output channels of the conv LSTM. output_channels: int, number of output channels of the conv LSTM.
kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
use_bias: (bool) Use bias in convolutions. use_bias: (bool) Use bias in convolutions.
skip_connection: If set to `True`, concatenate the input to the skip_connection: If set to `True`, concatenate the input to the
output of the conv LSTM. Default: `False`. output of the conv LSTM. Default: `False`.
@ -2092,7 +2092,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
self._conv_ndims = conv_ndims self._conv_ndims = conv_ndims
self._input_shape = input_shape self._input_shape = input_shape
self._output_channels = output_channels self._output_channels = output_channels
self._kernel_shape = kernel_shape self._kernel_shape = list(kernel_shape)
self._use_bias = use_bias self._use_bias = use_bias
self._forget_bias = forget_bias self._forget_bias = forget_bias
self._skip_connection = skip_connection self._skip_connection = skip_connection
@ -2172,7 +2172,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
Args: Args:
args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
batch x n, Tensors. batch x n, Tensors.
filter_size: int tuple of filter height and width. filter_size: int tuple of filter shape (of size 1, 2 or 3).
num_features: int, number of features. num_features: int, number of features.
bias: Whether to use biases in the convolution layer. bias: Whether to use biases in the convolution layer.
bias_start: starting value to initialize the bias; 0 by default. bias_start: starting value to initialize the bias; 0 by default.