From 9650c977e3b3ca73a0249a08f703960f478b3ea9 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Wed, 27 Feb 2019 15:48:45 -0800 Subject: [PATCH] Fix RNN cell mask reuse in true eager mode. PiperOrigin-RevId: 236008618 --- tensorflow/python/keras/backend.py | 7 +- .../keras/layers/convolutional_recurrent.py | 33 +-- .../layers/convolutional_recurrent_test.py | 2 - tensorflow/python/keras/layers/recurrent.py | 246 +++++++++++------- .../python/keras/layers/recurrent_test.py | 73 +++++- .../python/keras/layers/unified_lstm_test.py | 32 +-- ....experimental.-peephole-l-s-t-m-cell.pbtxt | 17 ++ .../tensorflow.keras.layers.-g-r-u-cell.pbtxt | 17 ++ ...ensorflow.keras.layers.-l-s-t-m-cell.pbtxt | 17 ++ ...flow.keras.layers.-simple-r-n-n-cell.pbtxt | 17 ++ ....experimental.-peephole-l-s-t-m-cell.pbtxt | 17 ++ .../tensorflow.keras.layers.-g-r-u-cell.pbtxt | 17 ++ .../v2/tensorflow.keras.layers.-g-r-u.pbtxt | 17 ++ ...ensorflow.keras.layers.-l-s-t-m-cell.pbtxt | 17 ++ .../v2/tensorflow.keras.layers.-l-s-t-m.pbtxt | 17 ++ ...flow.keras.layers.-simple-r-n-n-cell.pbtxt | 17 ++ 16 files changed, 421 insertions(+), 142 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b63127c26fc..fe97f746b07 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -3402,7 +3402,7 @@ def rnn(step_function, if unroll: if not time_steps: raise ValueError('Unrolling requires a fixed number of timesteps.') - states = initial_states + states = tuple(initial_states) successive_states = [] successive_outputs = [] @@ -3434,7 +3434,8 @@ def rnn(step_function, for i in range(time_steps): inp = _get_input_tensor(i) mask_t = mask_list[i] - output, new_states = step_function(inp, states + constants) + output, new_states = step_function(inp, + tuple(states) + tuple(constants)) tiled_mask_t = _expand_mask(mask_t, output) if not successive_outputs: @@ -3469,7 +3470,7 @@ def rnn(step_function, else: for i in range(time_steps): inp = _get_input_tensor(i) - output, states = step_function(inp, states + constants) + output, states = step_function(inp, tuple(states) + tuple(constants)) successive_outputs.append(output) successive_states.append(states) last_output = successive_outputs[-1] diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index e92df46d5ae..030908e51a2 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -28,8 +28,8 @@ from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_spec import InputSpec -from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask from tensorflow.python.keras.layers.recurrent import _standardize_args +from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin from tensorflow.python.keras.layers.recurrent import RNN from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import generic_utils @@ -482,7 +482,7 @@ class ConvRNN2D(RNN): K.set_value(state, value) -class ConvLSTM2DCell(Layer): +class ConvLSTM2DCell(DropoutRNNCellMixin, Layer): """Cell class for the ConvLSTM2D layer. Arguments: @@ -597,8 +597,6 @@ class ConvLSTM2DCell(Layer): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_size = (self.filters, self.filters) - self._dropout_mask = None - self._recurrent_dropout_mask = None def build(self, input_shape): @@ -648,28 +646,15 @@ class ConvLSTM2DCell(Layer): self.built = True def call(self, inputs, states, training=None): - if 0 < self.dropout < 1 and self._dropout_mask is None: - self._dropout_mask = _generate_dropout_mask( - K.ones_like(inputs), - self.dropout, - training=training, - count=4) - if (0 < self.recurrent_dropout < 1 and - self._recurrent_dropout_mask is None): - self._recurrent_dropout_mask = _generate_dropout_mask( - K.ones_like(states[1]), - self.recurrent_dropout, - training=training, - count=4) - - # dropout matrices for input units - dp_mask = self._dropout_mask - # dropout matrices for recurrent units - rec_dp_mask = self._recurrent_dropout_mask - h_tm1 = states[0] # previous memory state c_tm1 = states[1] # previous carry state + # dropout matrices for input units + dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) + # dropout matrices for recurrent units + rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( + h_tm1, training, count=4) + if 0 < self.dropout < 1.: inputs_i = inputs * dp_mask[0] inputs_f = inputs * dp_mask[1] @@ -945,6 +930,8 @@ class ConvLSTM2D(ConvRNN2D): self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): + self.cell.reset_dropout_mask() + self.cell.reset_recurrent_dropout_mask() return super(ConvLSTM2D, self).call(inputs, mask=mask, training=training, diff --git a/tensorflow/python/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/layers/convolutional_recurrent_test.py index 916cd63303a..d0da360ef5f 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent_test.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent_test.py @@ -172,8 +172,6 @@ class ConvLSTMTest(keras_parameterized.TestCase): self.assertEqual(len(layer.losses), 4) def test_conv_lstm_dropout(self): - if testing_utils.should_run_eagerly(): - self.skipTest('Skip test due to b/126246383.') # check dropout with self.cached_session(): testing_utils.layer_test( diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index d7054515af9..a90415220a2 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -1107,8 +1107,136 @@ class AbstractRNNCell(Layer): return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) +class DropoutRNNCellMixin(object): + """Object that hold dropout related fields for RNN Cell. + + This class is not a standalone RNN cell. It suppose to be used with a RNN cell + by multiple inheritance. Any cell that mix with class should have following + fields: + dropout: a float number within range [0, 1). The ratio that the input + tensor need to dropout. + recurrent_dropout: a float number within range [0, 1). The ratio that the + recurrent state weights need to dropout. + This object will create and cache created dropout masks, and reuse them for + the incoming data, so that the same mask is used for every batch input. + """ + + def __init__(self, *args, **kwargs): + # Note that the following two masks will be used in "graph function" mode, + # e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask` + # tensors will be generated differently than in the "graph function" case, + # and they will be cached. + # Also note that in graph mode, we still cache those masks only because the + # RNN could be created with `unroll=True`. In that case, the `cell.call()` + # function will be invoked multiple times, and we want to ensure same mask + # is used every time. + self._dropout_mask = None + self._recurrent_dropout_mask = None + self._eager_dropout_mask = None + self._eager_recurrent_dropout_mask = None + super(DropoutRNNCellMixin, self).__init__(*args, **kwargs) + + def reset_dropout_mask(self): + """Reset the cached dropout masks if any. + + This is important for the RNN layer to invoke this in it call() method so + that the cached mask is cleared before calling the cell.call(). The mask + should be cached across the timestep within the same batch, but shouldn't + be cached between batches. Otherwise it will introduce unreasonable bias + against certain index of data within the batch. + """ + self._dropout_mask = None + self._eager_dropout_mask = None + + def reset_recurrent_dropout_mask(self): + """Reset the cached recurrent dropout masks if any. + + This is important for the RNN layer to invoke this in it call() method so + that the cached mask is cleared before calling the cell.call(). The mask + should be cached across the timestep within the same batch, but shouldn't + be cached between batches. Otherwise it will introduce unreasonable bias + against certain index of data within the batch. + """ + self._recurrent_dropout_mask = None + self._eager_recurrent_dropout_mask = None + + def get_dropout_mask_for_cell(self, inputs, training, count=1): + """Get the dropout mask for RNN cell's input. + + It will create mask based on context if there isn't any existing cached + mask. If a new mask is generated, it will update the cache in the cell. + + Args: + inputs: the input tensor whose shape will be used to generate dropout + mask. + training: boolean tensor, whether its in training mode, dropout will be + ignored in non-training mode. + count: int, how many dropout mask will be generated. It is useful for cell + that has internal weights fused together. + Returns: + List of mask tensor, generated or cached mask based on context. + """ + if self.dropout == 0: + return None + if (not context.executing_eagerly() and self._dropout_mask is None + or context.executing_eagerly() and self._eager_dropout_mask is None): + # Generate new mask and cache it based on context. + dp_mask = _generate_dropout_mask( + array_ops.ones_like(inputs), + self.dropout, + training=training, + count=count) + if context.executing_eagerly(): + self._eager_dropout_mask = dp_mask + else: + self._dropout_mask = dp_mask + else: + # Reuse the existing mask. + dp_mask = (self._eager_dropout_mask + if context.executing_eagerly() else self._dropout_mask) + return dp_mask + + def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1): + """Get the recurrent dropout mask for RNN cell. + + It will create mask based on context if there isn't any existing cached + mask. If a new mask is generated, it will update the cache in the cell. + + Args: + inputs: the input tensor whose shape will be used to generate dropout + mask. + training: boolean tensor, whether its in training mode, dropout will be + ignored in non-training mode. + count: int, how many dropout mask will be generated. It is useful for cell + that has internal weights fused together. + Returns: + List of mask tensor, generated or cached mask based on context. + """ + if self.recurrent_dropout == 0: + return None + if (not context.executing_eagerly() and self._recurrent_dropout_mask is None + or context.executing_eagerly() + and self._eager_recurrent_dropout_mask is None): + # Generate new mask and cache it based on context. + rec_dp_mask = _generate_dropout_mask( + array_ops.ones_like(inputs), + self.recurrent_dropout, + training=training, + count=count) + if context.executing_eagerly(): + self._eager_recurrent_dropout_mask = rec_dp_mask + else: + self._recurrent_dropout_mask = rec_dp_mask + else: + # Reuse the existing mask. + rec_dp_mask = (self._eager_recurrent_dropout_mask + if context.executing_eagerly() + else self._recurrent_dropout_mask) + return rec_dp_mask + + @keras_export('keras.layers.SimpleRNNCell') -class SimpleRNNCell(Layer): +class SimpleRNNCell(DropoutRNNCellMixin, Layer): """Cell class for SimpleRNN. Arguments: @@ -1185,8 +1313,6 @@ class SimpleRNNCell(Layer): self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_size = self.units self.output_size = self.units - self._dropout_mask = None - self._recurrent_dropout_mask = None @tf_utils.shape_type_conversion def build(self, input_shape): @@ -1215,20 +1341,9 @@ class SimpleRNNCell(Layer): def call(self, inputs, states, training=None): prev_output = states[0] - if 0 < self.dropout < 1 and self._dropout_mask is None: - self._dropout_mask = _generate_dropout_mask( - array_ops.ones_like(inputs), - self.dropout, - training=training) - if (0 < self.recurrent_dropout < 1 and - self._recurrent_dropout_mask is None): - self._recurrent_dropout_mask = _generate_dropout_mask( - array_ops.ones_like(prev_output), - self.recurrent_dropout, - training=training) - - dp_mask = self._dropout_mask - rec_dp_mask = self._recurrent_dropout_mask + dp_mask = self.get_dropout_mask_for_cell(inputs, training) + rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( + prev_output, training) if dp_mask is not None: h = K.dot(inputs * dp_mask, self.kernel) @@ -1401,8 +1516,8 @@ class SimpleRNN(RNN): self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): - self.cell._dropout_mask = None - self.cell._recurrent_dropout_mask = None + self.cell.reset_dropout_mask() + self.cell.reset_recurrent_dropout_mask() return super(SimpleRNN, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -1507,7 +1622,7 @@ class SimpleRNN(RNN): @keras_export('keras.layers.GRUCell') -class GRUCell(Layer): +class GRUCell(DropoutRNNCellMixin, Layer): """Cell class for the GRU layer. Arguments: @@ -1604,8 +1719,6 @@ class GRUCell(Layer): self.reset_after = reset_after self.state_size = self.units self.output_size = self.units - self._dropout_mask = None - self._recurrent_dropout_mask = None @tf_utils.shape_type_conversion def build(self, input_shape): @@ -1644,24 +1757,9 @@ class GRUCell(Layer): def call(self, inputs, states, training=None): h_tm1 = states[0] # previous memory - if 0 < self.dropout < 1 and self._dropout_mask is None: - self._dropout_mask = _generate_dropout_mask( - array_ops.ones_like(inputs), - self.dropout, - training=training, - count=3) - if (0 < self.recurrent_dropout < 1 and - self._recurrent_dropout_mask is None): - self._recurrent_dropout_mask = _generate_dropout_mask( - array_ops.ones_like(h_tm1), - self.recurrent_dropout, - training=training, - count=3) - - # dropout matrices for input units - dp_mask = self._dropout_mask - # dropout matrices for recurrent units - rec_dp_mask = self._recurrent_dropout_mask + dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) + rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( + h_tm1, training, count=3) if self.use_bias: if not self.reset_after: @@ -1938,8 +2036,8 @@ class GRU(RNN): self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): - self.cell._dropout_mask = None - self.cell._recurrent_dropout_mask = None + self.cell.reset_dropout_mask() + self.cell.reset_recurrent_dropout_mask() return super(GRU, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -2062,7 +2160,7 @@ class GRU(RNN): @keras_export('keras.layers.GRU', v1=[]) -class UnifiedGRU(GRU): +class UnifiedGRU(DropoutRNNCellMixin, GRU): """Gated Recurrent Unit - Cho et al. 2014. Based on available runtime hardware and constraints, this layer @@ -2222,7 +2320,6 @@ class UnifiedGRU(GRU): time_major=time_major, reset_after=reset_after, **kwargs) - self._dropout_mask = None # CuDNN uses following setting by default and not configurable. self.could_use_cudnn = ( activation == 'tanh' and recurrent_activation == 'sigmoid' and @@ -2242,8 +2339,6 @@ class UnifiedGRU(GRU): if mask is not None or not self.could_use_cudnn: # CuDNN does not support masking, fall back to use the normal GRU. kwargs = {'training': training} - self.cell._dropout_mask = None - self.cell._recurrent_dropout_mask = None def step(cell_inputs, cell_states): return self.cell.call(cell_inputs, cell_states, **kwargs) @@ -2288,15 +2383,11 @@ class UnifiedGRU(GRU): if self.go_backwards: # Reverse time axis. inputs = K.reverse(inputs, 0 if self.time_major else 1) - if 0 < self.dropout < 1: - if self._dropout_mask is None: - self._dropout_mask = _generate_dropout_mask( - array_ops.ones_like(inputs), - self.dropout, - training=training, - count=3) - inputs *= self._dropout_mask[0] + self.reset_dropout_mask() + dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) + if dropout_mask is not None: + inputs *= dropout_mask[0] if ops.executing_eagerly_outside_functions(): # Under eager context, the device placement is already known. Prefer the # GPU implementation when GPU is available. @@ -2457,7 +2548,7 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, time_major): @keras_export('keras.layers.LSTMCell') -class LSTMCell(Layer): +class LSTMCell(DropoutRNNCellMixin, Layer): """Cell class for the LSTM layer. Arguments: @@ -2557,8 +2648,6 @@ class LSTMCell(Layer): self.implementation = implementation self.state_size = [self.units, self.units] self.output_size = self.units - self._dropout_mask = None - self._recurrent_dropout_mask = None @tf_utils.shape_type_conversion def build(self, input_shape): @@ -2621,28 +2710,13 @@ class LSTMCell(Layer): return c, o def call(self, inputs, states, training=None): - if 0 < self.dropout < 1 and self._dropout_mask is None: - self._dropout_mask = _generate_dropout_mask( - array_ops.ones_like(inputs), - self.dropout, - training=training, - count=4) - if (0 < self.recurrent_dropout < 1 and - self._recurrent_dropout_mask is None): - self._recurrent_dropout_mask = _generate_dropout_mask( - array_ops.ones_like(states[0]), - self.recurrent_dropout, - training=training, - count=4) - - # dropout matrices for input units - dp_mask = self._dropout_mask - # dropout matrices for recurrent units - rec_dp_mask = self._recurrent_dropout_mask - h_tm1 = states[0] # previous memory state c_tm1 = states[1] # previous carry state + dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) + rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( + h_tm1, training, count=4) + if self.implementation == 1: if 0 < self.dropout < 1.: inputs_i = inputs * dp_mask[0] @@ -2967,8 +3041,8 @@ class LSTM(RNN): self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): - self.cell._dropout_mask = None - self.cell._recurrent_dropout_mask = None + self.cell.reset_dropout_mask() + self.cell.reset_recurrent_dropout_mask() return super(LSTM, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -3091,7 +3165,7 @@ class LSTM(RNN): @keras_export('keras.layers.LSTM', v1=[]) -class UnifiedLSTM(LSTM): +class UnifiedLSTM(DropoutRNNCellMixin, LSTM): """Long Short-Term Memory layer - Hochreiter 1997. Based on available runtime hardware and constraints, this layer @@ -3234,7 +3308,6 @@ class UnifiedLSTM(LSTM): self.state_spec = [ InputSpec(shape=(None, dim)) for dim in (self.units, self.units) ] - self._dropout_mask = None self.could_use_cudnn = ( activation == 'tanh' and recurrent_activation == 'sigmoid' and recurrent_dropout == 0 and not unroll and use_bias) @@ -3278,15 +3351,10 @@ class UnifiedLSTM(LSTM): # Reverse time axis. inputs = K.reverse(inputs, 0 if self.time_major else 1) - if 0 < self.dropout < 1: - if self._dropout_mask is None: - self._dropout_mask = _generate_dropout_mask( - array_ops.ones_like(inputs), - self.dropout, - training=training, - count=4) - - inputs *= self._dropout_mask[0] + self.reset_dropout_mask() + dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) + if dropout_mask is not None: + inputs *= dropout_mask[0] if ops.executing_eagerly_outside_functions(): # Under eager context, the device placement is already known. Prefer the diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py index 1d7b1c6898a..f248c4bdf35 100644 --- a/tensorflow/python/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -23,13 +23,16 @@ from __future__ import print_function import collections +from absl.testing import parameterized import numpy as np from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.ops import array_ops @@ -725,14 +728,38 @@ class RNNTest(keras_parameterized.TestCase): y_np_2 = model.predict(x_np) self.assertAllClose(y_np, y_np_2, atol=1e-4) - def DISABLED_test_stacked_rnn_dropout(self): - # Temporarily disabled test due an occasional Grappler segfault. - # See b/115523414 - cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1), - keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)] - layer = keras.layers.RNN(cells) + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + layer=[keras.layers.SimpleRNN, keras.layers.GRU, keras.layers.LSTM, + keras.layers.UnifiedGRU, keras.layers.UnifiedLSTM], + unroll=[True, False])) + def test_rnn_dropout(self, layer, unroll): + rnn_layer = layer(3, dropout=0.1, recurrent_dropout=0.1, unroll=unroll) + if not unroll: + x = keras.Input((None, 5)) + else: + x = keras.Input((5, 5)) + y = rnn_layer(x) + model = keras.models.Model(x, y) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + x_np = np.random.random((6, 5, 5)) + y_np = np.random.random((6, 3)) + model.train_on_batch(x_np, y_np) - x = keras.Input((None, 5)) + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + cell=[keras.layers.SimpleRNNCell, keras.layers.GRUCell, + keras.layers.LSTMCell], + unroll=[True, False])) + def test_stacked_rnn_dropout(self, cell, unroll): + cells = [cell(3, dropout=0.1, recurrent_dropout=0.1), + cell(3, dropout=0.1, recurrent_dropout=0.1)] + layer = keras.layers.RNN(cells, unroll=unroll) + + if not unroll: + x = keras.Input((None, 5)) + else: + x = keras.Input((5, 5)) y = layer(x) model = keras.models.Model(x, y) model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) @@ -740,6 +767,38 @@ class RNNTest(keras_parameterized.TestCase): y_np = np.random.random((6, 3)) model.train_on_batch(x_np, y_np) + def test_dropout_mask_reuse(self): + # The layer is created with recurrent_initializer = zero, so that the + # the recurrent state won't affect the output. By doing this, we can verify + # the output and see if the same mask is applied to for each timestep. + rnn = keras.layers.SimpleRNN(3, + dropout=0.5, + kernel_initializer='ones', + recurrent_initializer='zeros', + return_sequences=True, + unroll=True) + + inputs = constant_op.constant(1.0, shape=(6, 2, 5)) + out = rnn(inputs, training=True) + if not context.executing_eagerly(): + self.evaluate(variables_lib.global_variables_initializer()) + batch_1 = self.evaluate(out) + batch_1_t0, batch_1_t1 = batch_1[:, 0, :], batch_1[:, 1, :] + self.assertAllClose(batch_1_t0, batch_1_t1) + + # This simulate the layer called with multiple batches in eager mode + if context.executing_eagerly(): + out2 = rnn(inputs, training=True) + else: + out2 = out + batch_2 = self.evaluate(out2) + batch_2_t0, batch_2_t1 = batch_2[:, 0, :], batch_2[:, 1, :] + self.assertAllClose(batch_2_t0, batch_2_t1) + + # Also validate that different dropout is used by between batches. + self.assertNotAllClose(batch_1_t0, batch_2_t0) + self.assertNotAllClose(batch_1_t1, batch_2_t1) + def test_stacked_rnn_compute_output_shape(self): cells = [keras.layers.LSTMCell(3), keras.layers.LSTMCell(6)] diff --git a/tensorflow/python/keras/layers/unified_lstm_test.py b/tensorflow/python/keras/layers/unified_lstm_test.py index 938c87c6b1a..01089e1165f 100644 --- a/tensorflow/python/keras/layers/unified_lstm_test.py +++ b/tensorflow/python/keras/layers/unified_lstm_test.py @@ -654,6 +654,20 @@ class UnifiedLSTMTest(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly()) model.fit(x, y, epochs=1, shuffle=False) + def test_dropout_LSTM(self): + num_samples = 2 + timesteps = 3 + embedding_dim = 4 + units = 2 + testing_utils.layer_test( + keras.layers.UnifiedLSTM, + kwargs={ + 'units': units, + 'dropout': 0.1, + 'recurrent_dropout': 0.1 + }, + input_shape=(num_samples, timesteps, embedding_dim)) + class LSTMLayerGraphOnlyTest(test.TestCase): @@ -763,24 +777,6 @@ class LSTMLayerGraphOnlyTest(test.TestCase): existing_loss = loss_value -class LSTMLayerV1OnlyTest(test.TestCase, parameterized.TestCase): - - @test_util.run_in_graph_and_eager_modes(config=_config) - def test_dropout_LSTM(self): - num_samples = 2 - timesteps = 3 - embedding_dim = 4 - units = 2 - testing_utils.layer_test( - keras.layers.UnifiedLSTM, - kwargs={ - 'units': units, - 'dropout': 0.1, - 'recurrent_dropout': 0.1 - }, - input_shape=(num_samples, timesteps, embedding_dim)) - - class UnifiedLSTMPerformanceTest(test.Benchmark): def _measure_performance(self, test_config, model, x_train, y_train): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt index 2f3cb0b7c51..c3127642e25 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.experimental.PeepholeLSTMCell" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -141,6 +142,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -173,6 +178,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -181,6 +190,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt index dd93e32ddce..9a660083ee0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.GRUCell" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -140,6 +141,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -172,6 +177,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -180,6 +189,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 7398613812d..66aad25f9af 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.LSTMCell" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -140,6 +141,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -172,6 +177,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -180,6 +189,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 5e799329c03..33a0c1976b0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.SimpleRNNCell" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -140,6 +141,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -172,6 +177,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -180,6 +189,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt index 2f3cb0b7c51..c3127642e25 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.experimental.PeepholeLSTMCell" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -141,6 +142,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -173,6 +178,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -181,6 +190,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt index dd93e32ddce..9a660083ee0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.GRUCell" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -140,6 +141,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -172,6 +177,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -180,6 +189,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt index 32e69856b9e..fb89501bfc8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.GRU" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -214,6 +215,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -246,6 +251,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -254,6 +263,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 7398613812d..66aad25f9af 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.LSTMCell" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -140,6 +141,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -172,6 +177,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -180,6 +189,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt index c9b759d7927..aee27ad4d59 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.LSTM" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -214,6 +215,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -246,6 +251,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -254,6 +263,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 5e799329c03..33a0c1976b0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.SimpleRNNCell" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -140,6 +141,10 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -172,6 +177,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_recurrent_dropout_mask_for_cell" + argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " + } member_method { name: "get_updates_for" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -180,6 +189,14 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "reset_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_recurrent_dropout_mask" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"