Merge pull request #26216 from lgeiger:keras-recurrent-updates

PiperOrigin-RevId: 301196514
Change-Id: I0b1d0f13e9bdfb87391b843f7245bb2a8713f605
This commit is contained in:
TensorFlower Gardener 2020-03-16 11:13:23 -07:00
commit e772406b52
4 changed files with 23 additions and 18 deletions

View File

@ -330,9 +330,10 @@ class ConvRNN2D(RNN):
mask=mask,
input_length=timesteps)
if self.stateful:
updates = []
for i in range(len(states)):
updates.append(K.update(self.states[i], states[i]))
updates = [
K.update(self_state, state)
for self_state, state in zip(self.states, states)
]
self.add_update(updates)
if self.return_sequences:

View File

@ -110,9 +110,10 @@ class _CuDNNRNN(RNN):
output, states = self._process_batch(inputs, initial_state)
if self.stateful:
updates = []
for i in range(len(states)):
updates.append(state_ops.assign(self.states[i], states[i]))
updates = [
state_ops.assign(self_state, state)
for self_state, state in zip(self.states, states)
]
self.add_update(updates)
if self.return_state:

View File

@ -608,16 +608,17 @@ class RNN(Layer):
'`cell.state_size`. Received `state_spec`={}; '
'however `cell.state_size` is '
'{}'.format(init_state_specs, cell_state_sizes))
flat_cell_state_size = nest.flatten(cell_state_sizes)
flat_state_spec = nest.flatten(init_state_specs)
flat_cell_state_sizes = nest.flatten(cell_state_sizes)
flat_state_specs = nest.flatten(init_state_specs)
if len(flat_cell_state_size) != len(flat_state_spec):
if len(flat_cell_state_sizes) != len(flat_state_specs):
raise validation_error
for i in range(len(flat_cell_state_size)):
for cell_state_spec, cell_state_size in zip(flat_state_specs,
flat_cell_state_sizes):
if not tensor_shape.TensorShape(
# Ignore the first axis for init_state which is for batch
flat_state_spec[i].shape[1:]).is_compatible_with(
tensor_shape.TensorShape(flat_cell_state_size[i])):
cell_state_spec.shape[1:]).is_compatible_with(
tensor_shape.TensorShape(cell_state_size)):
raise validation_error
@doc_controls.do_not_doc_inheritable
@ -794,9 +795,10 @@ class RNN(Layer):
zero_output_for_mask=self.zero_output_for_mask)
if self.stateful:
updates = []
for state_, state in zip(nest.flatten(self.states), nest.flatten(states)):
updates.append(state_ops.assign(state_, state))
updates = [
state_ops.assign(self_state, state) for self_state, state in zip(
nest.flatten(self.states), nest.flatten(states))
]
self.add_update(updates)
if self.return_sequences:

View File

@ -1179,9 +1179,10 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
states = [new_h, new_c]
if self.stateful:
updates = []
for i in range(len(states)):
updates.append(state_ops.assign(self.states[i], states[i]))
updates = [
state_ops.assign(self_state, state)
for self_state, state in zip(self.states, states)
]
self.add_update(updates)
if self.return_sequences: