From 6fca7d103745084480e8c909d539f17f78b0ab6f Mon Sep 17 00:00:00 2001 From: Oleksii Kuchaiev Date: Fri, 5 May 2017 18:32:13 -0700 Subject: [PATCH] Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell (#9704) * Making sure GLSTMCell is visible through tf.contrib.rnn.GLSTMCell * GLSTM: better way to infer batch size --- tensorflow/contrib/rnn/__init__.py | 1 + tensorflow/contrib/rnn/python/ops/rnn_cell.py | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index d2f9cde4f6b..2420c3e179b 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -46,6 +46,7 @@ See @{$python/contrib.rnn} guide. @@IntersectionRNNCell @@PhasedLSTMCell @@HighwayWrapper +@@GLSTMCell ### RNNCell wrappers @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 9672b8b85f0..7a0f894404c 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -2008,10 +2008,9 @@ class GLSTMCell(core_rnn_cell.RNNCell): subset of inputs corresponding to group "group_id", a Tensor, 2D, [batch x num_units/number_of_groups] """ - batch_size = inputs.shape[0].value or array_ops.shape(value)[0] return array_ops.slice(input_=inputs, begin=[0, group_id * group_size], - size=[batch_size, group_size], + size=[self._batch_size, group_size], name=("GLSTM_group%d_input_generation" % group_id)) def call(self, inputs, state): @@ -2039,9 +2038,7 @@ class GLSTMCell(core_rnn_cell.RNNCell): """ (c_prev, m_prev) = state - input_size = inputs.get_shape().with_rank(2)[1] - if input_size.value is None: - raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]") + self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] dtype = inputs.dtype scope = vs.get_variable_scope() with vs.variable_scope(scope, initializer=self._initializer):