Keras: Simplify recurrent state updates

This commit is contained in:
Lukas Geiger 2020-01-17 23:23:15 +00:00
parent 0197b66905
commit 0d64ea5b71
4 changed files with 20 additions and 18 deletions

View File

@ -392,9 +392,9 @@ class ConvRNN2D(RNN):
mask=mask,
input_length=timesteps)
if self.stateful:
updates = []
for i in range(len(states)):
updates.append(K.update(self.states[i], states[i]))
updates = [
K.update(self_state, state)
for self_state, state in zip(self.states, states)]
self.add_update(updates)
if self.return_sequences:

View File

@ -110,9 +110,9 @@ class _CuDNNRNN(RNN):
output, states = self._process_batch(inputs, initial_state)
if self.stateful:
updates = []
for i in range(len(states)):
updates.append(state_ops.assign(self.states[i], states[i]))
updates = [
state_ops.assign(self_state, state)
for self_state, state in zip(self.states, states)]
self.add_update(updates)
if self.return_state:

View File

@ -601,16 +601,17 @@ class RNN(Layer):
'`cell.state_size`. Received `state_spec`={}; '
'however `cell.state_size` is '
'{}'.format(init_state_specs, cell_state_sizes))
flat_cell_state_size = nest.flatten(cell_state_sizes)
flat_state_spec = nest.flatten(init_state_specs)
flat_cell_state_sizes = nest.flatten(cell_state_sizes)
flat_state_specs = nest.flatten(init_state_specs)
if len(flat_cell_state_size) != len(flat_state_spec):
if len(flat_cell_state_sizes) != len(flat_state_specs):
raise validation_error
for i in range(len(flat_cell_state_size)):
for cell_state_spec, cell_state_size in zip(
flat_state_specs, flat_cell_state_sizes):
if not tensor_shape.TensorShape(
# Ignore the first axis for init_state which is for batch
flat_state_spec[i].shape[1:]).is_compatible_with(
tensor_shape.TensorShape(flat_cell_state_size[i])):
cell_state_spec.shape[1:]).is_compatible_with(
tensor_shape.TensorShape(cell_state_size)):
raise validation_error
@doc_controls.do_not_doc_inheritable
@ -785,9 +786,10 @@ class RNN(Layer):
zero_output_for_mask=self.zero_output_for_mask)
if self.stateful:
updates = []
for state_, state in zip(nest.flatten(self.states), nest.flatten(states)):
updates.append(state_ops.assign(state_, state))
updates = [
state_ops.assign(self_state, state)
for self_state, state in zip(
nest.flatten(self.states), nest.flatten(states))]
self.add_update(updates)
if self.return_sequences:

View File

@ -1152,9 +1152,9 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
states = [new_h, new_c]
if self.stateful:
updates = []
for i in range(len(states)):
updates.append(state_ops.assign(self.states[i], states[i]))
updates = [
state_ops.assign(self_state, state)
for self_state, state in zip(self.states, states)]
self.add_update(updates)
if self.return_sequences: