Merge pull request #26216 from lgeiger:keras-recurrent-updates

PiperOrigin-RevId: 301196514
Change-Id: I0b1d0f13e9bdfb87391b843f7245bb2a8713f605
This commit is contained in:
TensorFlower Gardener 2020-03-16 11:13:23 -07:00
commit e772406b52
4 changed files with 23 additions and 18 deletions

View File

@ -330,9 +330,10 @@ class ConvRNN2D(RNN):
mask=mask, mask=mask,
input_length=timesteps) input_length=timesteps)
if self.stateful: if self.stateful:
updates = [] updates = [
for i in range(len(states)): K.update(self_state, state)
updates.append(K.update(self.states[i], states[i])) for self_state, state in zip(self.states, states)
]
self.add_update(updates) self.add_update(updates)
if self.return_sequences: if self.return_sequences:

View File

@ -110,9 +110,10 @@ class _CuDNNRNN(RNN):
output, states = self._process_batch(inputs, initial_state) output, states = self._process_batch(inputs, initial_state)
if self.stateful: if self.stateful:
updates = [] updates = [
for i in range(len(states)): state_ops.assign(self_state, state)
updates.append(state_ops.assign(self.states[i], states[i])) for self_state, state in zip(self.states, states)
]
self.add_update(updates) self.add_update(updates)
if self.return_state: if self.return_state:

View File

@ -608,16 +608,17 @@ class RNN(Layer):
'`cell.state_size`. Received `state_spec`={}; ' '`cell.state_size`. Received `state_spec`={}; '
'however `cell.state_size` is ' 'however `cell.state_size` is '
'{}'.format(init_state_specs, cell_state_sizes)) '{}'.format(init_state_specs, cell_state_sizes))
flat_cell_state_size = nest.flatten(cell_state_sizes) flat_cell_state_sizes = nest.flatten(cell_state_sizes)
flat_state_spec = nest.flatten(init_state_specs) flat_state_specs = nest.flatten(init_state_specs)
if len(flat_cell_state_size) != len(flat_state_spec): if len(flat_cell_state_sizes) != len(flat_state_specs):
raise validation_error raise validation_error
for i in range(len(flat_cell_state_size)): for cell_state_spec, cell_state_size in zip(flat_state_specs,
flat_cell_state_sizes):
if not tensor_shape.TensorShape( if not tensor_shape.TensorShape(
# Ignore the first axis for init_state which is for batch # Ignore the first axis for init_state which is for batch
flat_state_spec[i].shape[1:]).is_compatible_with( cell_state_spec.shape[1:]).is_compatible_with(
tensor_shape.TensorShape(flat_cell_state_size[i])): tensor_shape.TensorShape(cell_state_size)):
raise validation_error raise validation_error
@doc_controls.do_not_doc_inheritable @doc_controls.do_not_doc_inheritable
@ -794,9 +795,10 @@ class RNN(Layer):
zero_output_for_mask=self.zero_output_for_mask) zero_output_for_mask=self.zero_output_for_mask)
if self.stateful: if self.stateful:
updates = [] updates = [
for state_, state in zip(nest.flatten(self.states), nest.flatten(states)): state_ops.assign(self_state, state) for self_state, state in zip(
updates.append(state_ops.assign(state_, state)) nest.flatten(self.states), nest.flatten(states))
]
self.add_update(updates) self.add_update(updates)
if self.return_sequences: if self.return_sequences:

View File

@ -1179,9 +1179,10 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
states = [new_h, new_c] states = [new_h, new_c]
if self.stateful: if self.stateful:
updates = [] updates = [
for i in range(len(states)): state_ops.assign(self_state, state)
updates.append(state_ops.assign(self.states[i], states[i])) for self_state, state in zip(self.states, states)
]
self.add_update(updates) self.add_update(updates)
if self.return_sequences: if self.return_sequences: