From 0d64ea5b718930691daa03a88ce2607f8228a1e6 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Fri, 17 Jan 2020 23:23:15 +0000 Subject: [PATCH] Keras: Simplify recurrent state updates --- .../keras/layers/convolutional_recurrent.py | 6 +++--- .../python/keras/layers/cudnn_recurrent.py | 6 +++--- tensorflow/python/keras/layers/recurrent.py | 20 ++++++++++--------- .../python/keras/layers/recurrent_v2.py | 6 +++--- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index bc4ee3ce5bd..172800a84f9 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -392,9 +392,9 @@ class ConvRNN2D(RNN): mask=mask, input_length=timesteps) if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(K.update(self.states[i], states[i])) + updates = [ + K.update(self_state, state) + for self_state, state in zip(self.states, states)] self.add_update(updates) if self.return_sequences: diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py index ddd9a7ffa04..a82f70f098b 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent.py @@ -110,9 +110,9 @@ class _CuDNNRNN(RNN): output, states = self._process_batch(inputs, initial_state) if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(state_ops.assign(self.states[i], states[i])) + updates = [ + state_ops.assign(self_state, state) + for self_state, state in zip(self.states, states)] self.add_update(updates) if self.return_state: diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 05c3a593c9a..0e21e66a99b 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -601,16 +601,17 @@ class RNN(Layer): '`cell.state_size`. Received `state_spec`={}; ' 'however `cell.state_size` is ' '{}'.format(init_state_specs, cell_state_sizes)) - flat_cell_state_size = nest.flatten(cell_state_sizes) - flat_state_spec = nest.flatten(init_state_specs) + flat_cell_state_sizes = nest.flatten(cell_state_sizes) + flat_state_specs = nest.flatten(init_state_specs) - if len(flat_cell_state_size) != len(flat_state_spec): + if len(flat_cell_state_sizes) != len(flat_state_specs): raise validation_error - for i in range(len(flat_cell_state_size)): + for cell_state_spec, cell_state_size in zip( + flat_state_specs, flat_cell_state_sizes): if not tensor_shape.TensorShape( # Ignore the first axis for init_state which is for batch - flat_state_spec[i].shape[1:]).is_compatible_with( - tensor_shape.TensorShape(flat_cell_state_size[i])): + cell_state_spec.shape[1:]).is_compatible_with( + tensor_shape.TensorShape(cell_state_size)): raise validation_error @doc_controls.do_not_doc_inheritable @@ -785,9 +786,10 @@ class RNN(Layer): zero_output_for_mask=self.zero_output_for_mask) if self.stateful: - updates = [] - for state_, state in zip(nest.flatten(self.states), nest.flatten(states)): - updates.append(state_ops.assign(state_, state)) + updates = [ + state_ops.assign(self_state, state) + for self_state, state in zip( + nest.flatten(self.states), nest.flatten(states))] self.add_update(updates) if self.return_sequences: diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index 68d0884c54b..09b5b89d267 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -1152,9 +1152,9 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): states = [new_h, new_c] if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(state_ops.assign(self.states[i], states[i])) + updates = [ + state_ops.assign(self_state, state) + for self_state, state in zip(self.states, states)] self.add_update(updates) if self.return_sequences: