Have RNN classes pass their dtypes to their cells.
In TF 2, this makes RNNs work properly when a non-float32 dtype is passed to them. ConvLSTM2D is still broken with non-float32 dtypes, however, as it calls tf.zeros() in various places without passing the correct dtype.

PiperOrigin-RevId: 261368927
This commit is contained in:
parent
9274e93d4b
commit
d90e521d71
@ -921,7 +921,8 @@ class ConvLSTM2D(ConvRNN2D):
|
||||
recurrent_constraint=recurrent_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
dropout=dropout,
|
||||
recurrent_dropout=recurrent_dropout)
|
||||
recurrent_dropout=recurrent_dropout,
|
||||
dtype=kwargs.get('dtype'))
|
||||
super(ConvLSTM2D, self).__init__(cell,
|
||||
return_sequences=return_sequences,
|
||||
go_backwards=go_backwards,
|
||||
|
@ -43,6 +43,19 @@ class GRULayerTest(keras_parameterized.TestCase):
|
||||
'return_sequences': True},
|
||||
input_shape=(num_samples, timesteps, embedding_dim))
|
||||
|
||||
def test_float64_GRU(self):
    """Verify that a GRU layer builds and runs end-to-end with float64 inputs."""
    batch, steps, features = 2, 3, 4
    units = 2
    testing_utils.layer_test(
        keras.layers.GRU,
        kwargs={
            'units': units,
            'return_sequences': True,
            'dtype': 'float64',
        },
        input_shape=(batch, steps, features),
        input_dtype='float64')
|
||||
|
||||
def test_dynamic_behavior_GRU(self):
|
||||
num_samples = 2
|
||||
timesteps = 3
|
||||
|
@ -342,6 +342,19 @@ class GRUV2Test(keras_parameterized.TestCase):
|
||||
'return_sequences': True},
|
||||
input_shape=(num_samples, timesteps, embedding_dim))
|
||||
|
||||
def test_float64_GRU(self):
    """Verify that the v2 GRU layer builds and runs end-to-end with float64 inputs."""
    batch, steps, features = 2, 3, 4
    units = 2
    testing_utils.layer_test(
        rnn.GRU,
        kwargs={
            'units': units,
            'return_sequences': True,
            'dtype': 'float64',
        },
        input_shape=(batch, steps, features),
        input_dtype='float64')
|
||||
|
||||
def test_return_states_GRU(self):
|
||||
layer_class = rnn.GRU
|
||||
x = np.random.random((2, 3, 4))
|
||||
|
@ -44,6 +44,19 @@ class LSTMLayerTest(keras_parameterized.TestCase):
|
||||
'return_sequences': True},
|
||||
input_shape=(num_samples, timesteps, embedding_dim))
|
||||
|
||||
def test_float64_LSTM(self):
    """Verify that an LSTM layer builds and runs end-to-end with float64 inputs."""
    batch, steps, features = 2, 3, 4
    units = 2
    testing_utils.layer_test(
        keras.layers.LSTM,
        kwargs={
            'units': units,
            'return_sequences': True,
            'dtype': 'float64',
        },
        input_shape=(batch, steps, features),
        input_dtype='float64')
|
||||
|
||||
def test_static_shape_inference_LSTM(self):
|
||||
# Github issue: 15165
|
||||
timesteps = 3
|
||||
|
@ -565,6 +565,21 @@ class LSTMV2Test(keras_parameterized.TestCase):
|
||||
},
|
||||
input_shape=(num_samples, timesteps, embedding_dim))
|
||||
|
||||
def test_float64_LSTM(self):
    """Verify that the v2 LSTM layer builds and runs end-to-end with float64 inputs."""
    batch, steps, features = 2, 3, 4
    units = 2
    testing_utils.layer_test(
        rnn.LSTM,
        kwargs={'units': units,
                'return_sequences': True,
                'dtype': 'float64'},
        input_shape=(batch, steps, features),
        input_dtype='float64')
|
||||
|
||||
def test_regularizers_LSTM(self):
|
||||
embedding_dim = 4
|
||||
layer_class = rnn.LSTM
|
||||
|
@ -1362,7 +1362,8 @@ class SimpleRNN(RNN):
|
||||
recurrent_constraint=recurrent_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
dropout=dropout,
|
||||
recurrent_dropout=recurrent_dropout)
|
||||
recurrent_dropout=recurrent_dropout,
|
||||
dtype=kwargs.get('dtype'))
|
||||
super(SimpleRNN, self).__init__(
|
||||
cell,
|
||||
return_sequences=return_sequences,
|
||||
@ -1890,7 +1891,8 @@ class GRU(RNN):
|
||||
dropout=dropout,
|
||||
recurrent_dropout=recurrent_dropout,
|
||||
implementation=implementation,
|
||||
reset_after=reset_after)
|
||||
reset_after=reset_after,
|
||||
dtype=kwargs.get('dtype'))
|
||||
super(GRU, self).__init__(
|
||||
cell,
|
||||
return_sequences=return_sequences,
|
||||
@ -2516,7 +2518,8 @@ class LSTM(RNN):
|
||||
bias_constraint=bias_constraint,
|
||||
dropout=dropout,
|
||||
recurrent_dropout=recurrent_dropout,
|
||||
implementation=implementation)
|
||||
implementation=implementation,
|
||||
dtype=kwargs.get('dtype'))
|
||||
super(LSTM, self).__init__(
|
||||
cell,
|
||||
return_sequences=return_sequences,
|
||||
|
@ -42,6 +42,19 @@ class SimpleRNNLayerTest(keras_parameterized.TestCase):
|
||||
'return_sequences': True},
|
||||
input_shape=(num_samples, timesteps, embedding_dim))
|
||||
|
||||
def test_float64_SimpleRNN(self):
    """Verify that a SimpleRNN layer builds and runs end-to-end with float64 inputs."""
    batch, steps, features = 2, 3, 4
    units = 2
    testing_utils.layer_test(
        keras.layers.SimpleRNN,
        kwargs={
            'units': units,
            'return_sequences': True,
            'dtype': 'float64',
        },
        input_shape=(batch, steps, features),
        input_dtype='float64')
|
||||
|
||||
def test_dynamic_behavior_SimpleRNN(self):
|
||||
num_samples = 2
|
||||
timesteps = 3
|
||||
|
Loading…
Reference in New Issue
Block a user