Have RNN classes pass their dtypes to their cells.

In TF 2, this makes RNNs work properly when a non-float32 dtype is passed to them. However, ConvLSTM2D is still broken with non-float32 dtypes, as it calls tf.zeros() in various places without passing the correct dtype.

PiperOrigin-RevId: 261368927
This commit is contained in:
Reed Wanderman-Milne 2019-08-02 12:24:10 -07:00 committed by TensorFlower Gardener
parent 9274e93d4b
commit d90e521d71
7 changed files with 75 additions and 4 deletions

View File

@ -921,7 +921,8 @@ class ConvLSTM2D(ConvRNN2D):
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout)
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'))
super(ConvLSTM2D, self).__init__(cell,
return_sequences=return_sequences,
go_backwards=go_backwards,

View File

@ -43,6 +43,19 @@ class GRULayerTest(keras_parameterized.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
def test_float64_GRU(self):
    """GRU should build and run end-to-end with float64 inputs and weights."""
    batch, steps, input_dim = 2, 3, 4
    units = 2
    # layer_test builds the layer, runs a forward pass, and round-trips the
    # config; passing dtype='float64' exercises dtype propagation to the cell.
    testing_utils.layer_test(
        keras.layers.GRU,
        kwargs=dict(units=units, return_sequences=True, dtype='float64'),
        input_shape=(batch, steps, input_dim),
        input_dtype='float64')
def test_dynamic_behavior_GRU(self):
num_samples = 2
timesteps = 3

View File

@ -342,6 +342,19 @@ class GRUV2Test(keras_parameterized.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
def test_float64_GRU(self):
    """The v2 GRU should build and run end-to-end in float64."""
    batch, steps, input_dim = 2, 3, 4
    units = 2
    # layer_test exercises build/call/config round-trip; dtype='float64'
    # checks that the layer forwards its dtype to the cell.
    testing_utils.layer_test(
        rnn.GRU,
        kwargs=dict(units=units, return_sequences=True, dtype='float64'),
        input_shape=(batch, steps, input_dim),
        input_dtype='float64')
def test_return_states_GRU(self):
layer_class = rnn.GRU
x = np.random.random((2, 3, 4))

View File

@ -44,6 +44,19 @@ class LSTMLayerTest(keras_parameterized.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
def test_float64_LSTM(self):
    """LSTM should build and run end-to-end with float64 inputs and weights."""
    batch, steps, input_dim = 2, 3, 4
    units = 2
    # dtype='float64' verifies the layer passes its dtype down to the cell.
    testing_utils.layer_test(
        keras.layers.LSTM,
        kwargs=dict(units=units, return_sequences=True, dtype='float64'),
        input_shape=(batch, steps, input_dim),
        input_dtype='float64')
def test_static_shape_inference_LSTM(self):
# Github issue: 15165
timesteps = 3

View File

@ -565,6 +565,21 @@ class LSTMV2Test(keras_parameterized.TestCase):
},
input_shape=(num_samples, timesteps, embedding_dim))
def test_float64_LSTM(self):
    """The v2 LSTM should build and run end-to-end in float64."""
    batch, steps, input_dim = 2, 3, 4
    units = 2
    # layer_test covers build/call/config round-trip; the float64 dtype
    # exercises dtype propagation from the layer to its cell.
    testing_utils.layer_test(
        rnn.LSTM,
        kwargs=dict(units=units, return_sequences=True, dtype='float64'),
        input_shape=(batch, steps, input_dim),
        input_dtype='float64')
def test_regularizers_LSTM(self):
embedding_dim = 4
layer_class = rnn.LSTM

View File

@ -1362,7 +1362,8 @@ class SimpleRNN(RNN):
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout)
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'))
super(SimpleRNN, self).__init__(
cell,
return_sequences=return_sequences,
@ -1890,7 +1891,8 @@ class GRU(RNN):
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation,
reset_after=reset_after)
reset_after=reset_after,
dtype=kwargs.get('dtype'))
super(GRU, self).__init__(
cell,
return_sequences=return_sequences,
@ -2516,7 +2518,8 @@ class LSTM(RNN):
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation)
implementation=implementation,
dtype=kwargs.get('dtype'))
super(LSTM, self).__init__(
cell,
return_sequences=return_sequences,

View File

@ -42,6 +42,19 @@ class SimpleRNNLayerTest(keras_parameterized.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
def test_float64_SimpleRNN(self):
    """SimpleRNN should build and run end-to-end in float64."""
    batch, steps, input_dim = 2, 3, 4
    units = 2
    # dtype='float64' verifies the layer forwards its dtype to the cell.
    testing_utils.layer_test(
        keras.layers.SimpleRNN,
        kwargs=dict(units=units, return_sequences=True, dtype='float64'),
        input_shape=(batch, steps, input_dim),
        input_dtype='float64')
def test_dynamic_behavior_SimpleRNN(self):
num_samples = 2
timesteps = 3