Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell (#9704)
* Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell * GLSTM: better way to infer batch size
This commit is contained in:
parent
6bfbcf31dc
commit
6fca7d1037
@ -46,6 +46,7 @@ See @{$python/contrib.rnn} guide.
|
||||
@@IntersectionRNNCell
|
||||
@@PhasedLSTMCell
|
||||
@@HighwayWrapper
|
||||
@@GLSTMCell
|
||||
|
||||
### RNNCell wrappers
|
||||
@@AttentionCellWrapper
|
||||
|
@ -2008,10 +2008,9 @@ class GLSTMCell(core_rnn_cell.RNNCell):
|
||||
subset of inputs corresponding to group "group_id",
|
||||
a Tensor, 2D, [batch x num_units/number_of_groups]
|
||||
"""
|
||||
batch_size = inputs.shape[0].value or array_ops.shape(value)[0]
|
||||
return array_ops.slice(input_=inputs,
|
||||
begin=[0, group_id * group_size],
|
||||
size=[batch_size, group_size],
|
||||
size=[self._batch_size, group_size],
|
||||
name=("GLSTM_group%d_input_generation" % group_id))
|
||||
|
||||
def call(self, inputs, state):
|
||||
@ -2039,9 +2038,7 @@ class GLSTMCell(core_rnn_cell.RNNCell):
|
||||
"""
|
||||
(c_prev, m_prev) = state
|
||||
|
||||
input_size = inputs.get_shape().with_rank(2)[1]
|
||||
if input_size.value is None:
|
||||
raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]")
|
||||
self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
|
||||
dtype = inputs.dtype
|
||||
scope = vs.get_variable_scope()
|
||||
with vs.variable_scope(scope, initializer=self._initializer):
|
||||
|
Loading…
Reference in New Issue
Block a user