Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell (#9704)

* Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell

* GLSTM: better way to infer batch size
This commit is contained in:
Oleksii Kuchaiev 2017-05-05 18:32:13 -07:00 committed by Vijay Vasudevan
parent 6bfbcf31dc
commit 6fca7d1037
2 changed files with 3 additions and 5 deletions

View File

@ -46,6 +46,7 @@ See @{$python/contrib.rnn} guide.
@@IntersectionRNNCell
@@PhasedLSTMCell
@@HighwayWrapper
@@GLSTMCell
### RNNCell wrappers
@@AttentionCellWrapper

View File

@ -2008,10 +2008,9 @@ class GLSTMCell(core_rnn_cell.RNNCell):
subset of inputs corresponding to group "group_id",
a Tensor, 2D, [batch x num_units/number_of_groups]
"""
batch_size = inputs.shape[0].value or array_ops.shape(value)[0]
return array_ops.slice(input_=inputs,
begin=[0, group_id * group_size],
size=[batch_size, group_size],
size=[self._batch_size, group_size],
name=("GLSTM_group%d_input_generation" % group_id))
def call(self, inputs, state):
@ -2039,9 +2038,7 @@ class GLSTMCell(core_rnn_cell.RNNCell):
"""
(c_prev, m_prev) = state
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]")
self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
dtype = inputs.dtype
scope = vs.get_variable_scope()
with vs.variable_scope(scope, initializer=self._initializer):