Merge pull request #26216 from lgeiger:keras-recurrent-updates
PiperOrigin-RevId: 301196514 Change-Id: I0b1d0f13e9bdfb87391b843f7245bb2a8713f605
This commit is contained in:
commit
e772406b52
@ -330,9 +330,10 @@ class ConvRNN2D(RNN):
|
||||
mask=mask,
|
||||
input_length=timesteps)
|
||||
if self.stateful:
|
||||
updates = []
|
||||
for i in range(len(states)):
|
||||
updates.append(K.update(self.states[i], states[i]))
|
||||
updates = [
|
||||
K.update(self_state, state)
|
||||
for self_state, state in zip(self.states, states)
|
||||
]
|
||||
self.add_update(updates)
|
||||
|
||||
if self.return_sequences:
|
||||
|
@ -110,9 +110,10 @@ class _CuDNNRNN(RNN):
|
||||
output, states = self._process_batch(inputs, initial_state)
|
||||
|
||||
if self.stateful:
|
||||
updates = []
|
||||
for i in range(len(states)):
|
||||
updates.append(state_ops.assign(self.states[i], states[i]))
|
||||
updates = [
|
||||
state_ops.assign(self_state, state)
|
||||
for self_state, state in zip(self.states, states)
|
||||
]
|
||||
self.add_update(updates)
|
||||
|
||||
if self.return_state:
|
||||
|
@ -608,16 +608,17 @@ class RNN(Layer):
|
||||
'`cell.state_size`. Received `state_spec`={}; '
|
||||
'however `cell.state_size` is '
|
||||
'{}'.format(init_state_specs, cell_state_sizes))
|
||||
flat_cell_state_size = nest.flatten(cell_state_sizes)
|
||||
flat_state_spec = nest.flatten(init_state_specs)
|
||||
flat_cell_state_sizes = nest.flatten(cell_state_sizes)
|
||||
flat_state_specs = nest.flatten(init_state_specs)
|
||||
|
||||
if len(flat_cell_state_size) != len(flat_state_spec):
|
||||
if len(flat_cell_state_sizes) != len(flat_state_specs):
|
||||
raise validation_error
|
||||
for i in range(len(flat_cell_state_size)):
|
||||
for cell_state_spec, cell_state_size in zip(flat_state_specs,
|
||||
flat_cell_state_sizes):
|
||||
if not tensor_shape.TensorShape(
|
||||
# Ignore the first axis for init_state which is for batch
|
||||
flat_state_spec[i].shape[1:]).is_compatible_with(
|
||||
tensor_shape.TensorShape(flat_cell_state_size[i])):
|
||||
cell_state_spec.shape[1:]).is_compatible_with(
|
||||
tensor_shape.TensorShape(cell_state_size)):
|
||||
raise validation_error
|
||||
|
||||
@doc_controls.do_not_doc_inheritable
|
||||
@ -794,9 +795,10 @@ class RNN(Layer):
|
||||
zero_output_for_mask=self.zero_output_for_mask)
|
||||
|
||||
if self.stateful:
|
||||
updates = []
|
||||
for state_, state in zip(nest.flatten(self.states), nest.flatten(states)):
|
||||
updates.append(state_ops.assign(state_, state))
|
||||
updates = [
|
||||
state_ops.assign(self_state, state) for self_state, state in zip(
|
||||
nest.flatten(self.states), nest.flatten(states))
|
||||
]
|
||||
self.add_update(updates)
|
||||
|
||||
if self.return_sequences:
|
||||
|
@ -1179,9 +1179,10 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
|
||||
states = [new_h, new_c]
|
||||
|
||||
if self.stateful:
|
||||
updates = []
|
||||
for i in range(len(states)):
|
||||
updates.append(state_ops.assign(self.states[i], states[i]))
|
||||
updates = [
|
||||
state_ops.assign(self_state, state)
|
||||
for self_state, state in zip(self.states, states)
|
||||
]
|
||||
self.add_update(updates)
|
||||
|
||||
if self.return_sequences:
|
||||
|
Loading…
x
Reference in New Issue
Block a user