diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index 7b8b51c5276..48f724b55e1 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -330,9 +330,10 @@ class ConvRNN2D(RNN): mask=mask, input_length=timesteps) if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(K.update(self.states[i], states[i])) + updates = [ + K.update(self_state, state) + for self_state, state in zip(self.states, states) + ] self.add_update(updates) if self.return_sequences: diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py index ddd9a7ffa04..fc393ea7290 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent.py @@ -110,9 +110,10 @@ class _CuDNNRNN(RNN): output, states = self._process_batch(inputs, initial_state) if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(state_ops.assign(self.states[i], states[i])) + updates = [ + state_ops.assign(self_state, state) + for self_state, state in zip(self.states, states) + ] self.add_update(updates) if self.return_state: diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 87df1cddc05..ad7905f4caa 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -608,16 +608,17 @@ class RNN(Layer): '`cell.state_size`. Received `state_spec`={}; ' 'however `cell.state_size` is ' '{}'.format(init_state_specs, cell_state_sizes)) - flat_cell_state_size = nest.flatten(cell_state_sizes) - flat_state_spec = nest.flatten(init_state_specs) + flat_cell_state_sizes = nest.flatten(cell_state_sizes) + flat_state_specs = nest.flatten(init_state_specs) - if len(flat_cell_state_size) != len(flat_state_spec): + if len(flat_cell_state_sizes) != len(flat_state_specs): raise validation_error - for i in range(len(flat_cell_state_size)): + for cell_state_spec, cell_state_size in zip(flat_state_specs, + flat_cell_state_sizes): if not tensor_shape.TensorShape( # Ignore the first axis for init_state which is for batch - flat_state_spec[i].shape[1:]).is_compatible_with( - tensor_shape.TensorShape(flat_cell_state_size[i])): + cell_state_spec.shape[1:]).is_compatible_with( + tensor_shape.TensorShape(cell_state_size)): raise validation_error @doc_controls.do_not_doc_inheritable @@ -794,9 +795,10 @@ class RNN(Layer): zero_output_for_mask=self.zero_output_for_mask) if self.stateful: - updates = [] - for state_, state in zip(nest.flatten(self.states), nest.flatten(states)): - updates.append(state_ops.assign(state_, state)) + updates = [ + state_ops.assign(self_state, state) for self_state, state in zip( + nest.flatten(self.states), nest.flatten(states)) + ] self.add_update(updates) if self.return_sequences: diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index f50031cd1c3..a9d5ef8587c 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -1179,9 +1179,10 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): states = [new_h, new_c] if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(state_ops.assign(self.states[i], states[i])) + updates = [ + state_ops.assign(self_state, state) + for self_state, state in zip(self.states, states) + ] self.add_update(updates) if self.return_sequences: