Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell (#9704)
* Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell * GLSTM: better way to infer batch size
This commit is contained in:
parent
6bfbcf31dc
commit
6fca7d1037
@ -46,6 +46,7 @@ See @{$python/contrib.rnn} guide.
|
|||||||
@@IntersectionRNNCell
|
@@IntersectionRNNCell
|
||||||
@@PhasedLSTMCell
|
@@PhasedLSTMCell
|
||||||
@@HighwayWrapper
|
@@HighwayWrapper
|
||||||
|
@@GLSTMCell
|
||||||
|
|
||||||
### RNNCell wrappers
|
### RNNCell wrappers
|
||||||
@@AttentionCellWrapper
|
@@AttentionCellWrapper
|
||||||
|
@ -2008,10 +2008,9 @@ class GLSTMCell(core_rnn_cell.RNNCell):
|
|||||||
subset of inputs corresponding to group "group_id",
|
subset of inputs corresponding to group "group_id",
|
||||||
a Tensor, 2D, [batch x num_units/number_of_groups]
|
a Tensor, 2D, [batch x num_units/number_of_groups]
|
||||||
"""
|
"""
|
||||||
batch_size = inputs.shape[0].value or array_ops.shape(value)[0]
|
|
||||||
return array_ops.slice(input_=inputs,
|
return array_ops.slice(input_=inputs,
|
||||||
begin=[0, group_id * group_size],
|
begin=[0, group_id * group_size],
|
||||||
size=[batch_size, group_size],
|
size=[self._batch_size, group_size],
|
||||||
name=("GLSTM_group%d_input_generation" % group_id))
|
name=("GLSTM_group%d_input_generation" % group_id))
|
||||||
|
|
||||||
def call(self, inputs, state):
|
def call(self, inputs, state):
|
||||||
@ -2039,9 +2038,7 @@ class GLSTMCell(core_rnn_cell.RNNCell):
|
|||||||
"""
|
"""
|
||||||
(c_prev, m_prev) = state
|
(c_prev, m_prev) = state
|
||||||
|
|
||||||
input_size = inputs.get_shape().with_rank(2)[1]
|
self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
|
||||||
if input_size.value is None:
|
|
||||||
raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]")
|
|
||||||
dtype = inputs.dtype
|
dtype = inputs.dtype
|
||||||
scope = vs.get_variable_scope()
|
scope = vs.get_variable_scope()
|
||||||
with vs.variable_scope(scope, initializer=self._initializer):
|
with vs.variable_scope(scope, initializer=self._initializer):
|
||||||
|
Loading…
Reference in New Issue
Block a user