[RNN] Update TF GRU and BasicRNN cell to support Keras RNN layer.
The Keras RNN layer will feed in state_t as a list, and expects state_t+1 as a list as well, even if the state is a single tensor for a simple cell. On the other hand, TF RNN cells do not wrap a single-tensor state in a list. PiperOrigin-RevId: 208666836
This commit is contained in:
		
							parent
							
								
									4be575a6c2
								
							
						
					
					
						commit
						87e5662309
					
				| @ -670,6 +670,8 @@ class RNN(Layer): | ||||
|     if generic_utils.has_arg(self.cell.call, 'training'): | ||||
|       kwargs['training'] = training | ||||
| 
 | ||||
|     # TF RNN cells expect single tensor as state instead of list wrapped tensor. | ||||
|     is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None | ||||
|     if constants: | ||||
|       if not generic_utils.has_arg(self.cell.call, 'constants'): | ||||
|         raise ValueError('RNN cell does not support constants') | ||||
| @ -677,11 +679,21 @@ class RNN(Layer): | ||||
|       def step(inputs, states): | ||||
|         constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type | ||||
|         states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type | ||||
|         return self.cell.call(inputs, states, constants=constants, **kwargs) | ||||
| 
 | ||||
|         states = states[0] if len(states) == 1 and is_tf_rnn_cell else states | ||||
|         output, new_states = self.cell.call( | ||||
|             inputs, states, constants=constants, **kwargs) | ||||
|         if not nest.is_sequence(new_states): | ||||
|           new_states = [new_states] | ||||
|         return output, new_states | ||||
|     else: | ||||
| 
 | ||||
|       def step(inputs, states): | ||||
|         return self.cell.call(inputs, states, **kwargs) | ||||
|         states = states[0] if len(states) == 1 and is_tf_rnn_cell else states | ||||
|         output, new_states = self.cell.call(inputs, states, **kwargs) | ||||
|         if not nest.is_sequence(new_states): | ||||
|           new_states = [new_states] | ||||
|         return output, new_states | ||||
| 
 | ||||
|     last_output, outputs, states = K.rnn( | ||||
|         step, | ||||
|  | ||||
| @ -301,14 +301,12 @@ class RNNTest(test.TestCase): | ||||
|       self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias)) | ||||
| 
 | ||||
|   def testRNNCellSerialization(self): | ||||
|     for cell in  [ | ||||
|     for cell in [ | ||||
|         rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True), | ||||
|         rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32), | ||||
|         # TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras. | ||||
|         # rnn_cell_impl.BasicRNNCell( | ||||
|         #     32, activation="relu", dtype=dtypes.float32), | ||||
|         # rnn_cell_impl.GRUCell( | ||||
|         #     32, kernel_initializer="ones", dtype=dtypes.float32) | ||||
|         rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32), | ||||
|         rnn_cell_impl.GRUCell( | ||||
|             32, kernel_initializer="ones", dtype=dtypes.float32) | ||||
|     ]: | ||||
|       with self.test_session(): | ||||
|         x = keras.Input((None, 5)) | ||||
| @ -326,11 +324,13 @@ class RNNTest(test.TestCase): | ||||
|         # not visible as a Keras layer, and also has a name conflict with | ||||
|         # keras.LSTMCell and GRUCell. | ||||
|         layer = keras.layers.RNN.from_config( | ||||
|             config, custom_objects={ | ||||
|                 # "BasicRNNCell": rnn_cell_impl.BasicRNNCell, | ||||
|                 # "GRUCell": rnn_cell_impl.GRUCell, | ||||
|             config, | ||||
|             custom_objects={ | ||||
|                 "BasicRNNCell": rnn_cell_impl.BasicRNNCell, | ||||
|                 "GRUCell": rnn_cell_impl.GRUCell, | ||||
|                 "LSTMCell": rnn_cell_impl.LSTMCell, | ||||
|                 "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell}) | ||||
|                 "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell | ||||
|             }) | ||||
|         y = layer(x) | ||||
|         model = keras.models.Model(x, y) | ||||
|         model.set_weights(weights) | ||||
|  | ||||
| @ -193,6 +193,13 @@ class RNNCell(base_layer.Layer): | ||||
|   for each `s` in `self.batch_size`. | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self, trainable=True, name=None, dtype=None, **kwargs): | ||||
|     super(RNNCell, self).__init__( | ||||
|         trainable=trainable, name=name, dtype=dtype, **kwargs) | ||||
|     # Attribute that indicates whether the cell is a TF RNN cell, due the slight | ||||
|     # difference between TF and Keras RNN cell. | ||||
|     self._is_tf_rnn_cell = True | ||||
| 
 | ||||
|   def __call__(self, inputs, state, scope=None): | ||||
|     """Run this RNN cell on inputs, starting from the given state. | ||||
| 
 | ||||
| @ -524,8 +531,8 @@ class GRUCell(LayerRNNCell): | ||||
|   def get_config(self): | ||||
|     config = { | ||||
|         "num_units": self._num_units, | ||||
|         "initializer": initializers.serialize(self._initializer), | ||||
|         "kernel_initializer": initializers.serialize(self._kernel_initializer), | ||||
|         "bias_initializer": initializers.serialize(self._bias_initializer), | ||||
|         "activation": activations.serialize(self._activation), | ||||
|         "reuse": self._reuse, | ||||
|     } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user