Fix RNN cell mask reuse in true eager mode.
PiperOrigin-RevId: 236008618
commit 9650c977e3 (parent 746397a4ed)
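The contract being fixed: within one forward pass, a cell's dropout mask must stay fixed across timesteps, while each new forward pass must draw a fresh mask. The layers previously cleared cached masks by assigning to private cell attributes, which did not hold up under true eager execution; the change below introduces a DropoutRNNCellMixin that keeps separate graph and eager caches behind explicit reset and getter methods. A minimal sketch of the intended behavior, mirroring the test_dropout_mask_reuse test added below (shapes and values are illustrative):

import tensorflow as tf  # assumes a build that includes this change, run eagerly

# With recurrent_initializer='zeros', each timestep's output depends only on
# the dropout-masked input, so equal outputs across timesteps imply that one
# mask was reused for the whole pass.
rnn = tf.keras.layers.SimpleRNN(3, dropout=0.5, kernel_initializer='ones',
                                recurrent_initializer='zeros',
                                return_sequences=True, unroll=True)
inputs = tf.ones((6, 2, 5))
batch_1 = rnn(inputs, training=True)  # outputs at t0 and t1 match
batch_2 = rnn(inputs, training=True)  # fresh mask: differs from batch_1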
@@ -3402,7 +3402,7 @@ def rnn(step_function,
   if unroll:
     if not time_steps:
       raise ValueError('Unrolling requires a fixed number of timesteps.')
-    states = initial_states
+    states = tuple(initial_states)
     successive_states = []
     successive_outputs = []

@@ -3434,7 +3434,8 @@ def rnn(step_function,
       for i in range(time_steps):
        inp = _get_input_tensor(i)
        mask_t = mask_list[i]
-        output, new_states = step_function(inp, states + constants)
+        output, new_states = step_function(inp,
+                                           tuple(states) + tuple(constants))
        tiled_mask_t = _expand_mask(mask_t, output)

        if not successive_outputs:
@@ -3469,7 +3470,7 @@ def rnn(step_function,
    else:
      for i in range(time_steps):
        inp = _get_input_tensor(i)
-        output, states = step_function(inp, states + constants)
+        output, states = step_function(inp, tuple(states) + tuple(constants))
        successive_outputs.append(output)
        successive_states.append(states)
      last_output = successive_outputs[-1]
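A note on the tuple() coercions above: initial_states can arrive as a Python list while constants is a tuple, and concatenating a list with a tuple raises a TypeError, so both operands are normalized before the + inside the step_function call. A plain-Python illustration (names hypothetical):

states = [h_state]                 # list
constants = (c_tensor,)            # tuple
# states + constants             -> TypeError: can only concatenate list (not "tuple") to list
merged = tuple(states) + tuple(constants)  # OK: (h_state, c_tensor)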
@@ -28,8 +28,8 @@ from tensorflow.python.keras import initializers
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
-from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask
 from tensorflow.python.keras.layers.recurrent import _standardize_args
+from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
 from tensorflow.python.keras.layers.recurrent import RNN
 from tensorflow.python.keras.utils import conv_utils
 from tensorflow.python.keras.utils import generic_utils
@@ -482,7 +482,7 @@ class ConvRNN2D(RNN):
       K.set_value(state, value)


-class ConvLSTM2DCell(Layer):
+class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
   """Cell class for the ConvLSTM2D layer.

   Arguments:
@@ -597,8 +597,6 @@ class ConvLSTM2DCell(Layer):
     self.dropout = min(1., max(0., dropout))
     self.recurrent_dropout = min(1., max(0., recurrent_dropout))
     self.state_size = (self.filters, self.filters)
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None

   def build(self, input_shape):

@@ -648,28 +646,15 @@ class ConvLSTM2DCell(Layer):
     self.built = True

   def call(self, inputs, states, training=None):
-    if 0 < self.dropout < 1 and self._dropout_mask is None:
-      self._dropout_mask = _generate_dropout_mask(
-          K.ones_like(inputs),
-          self.dropout,
-          training=training,
-          count=4)
-    if (0 < self.recurrent_dropout < 1 and
-        self._recurrent_dropout_mask is None):
-      self._recurrent_dropout_mask = _generate_dropout_mask(
-          K.ones_like(states[1]),
-          self.recurrent_dropout,
-          training=training,
-          count=4)
-
-    # dropout matrices for input units
-    dp_mask = self._dropout_mask
-    # dropout matrices for recurrent units
-    rec_dp_mask = self._recurrent_dropout_mask
-
     h_tm1 = states[0]  # previous memory state
     c_tm1 = states[1]  # previous carry state

+    # dropout matrices for input units
+    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
+    # dropout matrices for recurrent units
+    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
+        h_tm1, training, count=4)
+
     if 0 < self.dropout < 1.:
       inputs_i = inputs * dp_mask[0]
       inputs_f = inputs * dp_mask[1]
@@ -945,6 +930,8 @@ class ConvLSTM2D(ConvRNN2D):
     self.activity_regularizer = regularizers.get(activity_regularizer)

   def call(self, inputs, mask=None, training=None, initial_state=None):
+    self.cell.reset_dropout_mask()
+    self.cell.reset_recurrent_dropout_mask()
     return super(ConvLSTM2D, self).call(inputs,
                                         mask=mask,
                                         training=training,
@@ -172,8 +172,6 @@ class ConvLSTMTest(keras_parameterized.TestCase):
     self.assertEqual(len(layer.losses), 4)

   def test_conv_lstm_dropout(self):
-    if testing_utils.should_run_eagerly():
-      self.skipTest('Skip test due to b/126246383.')
     # check dropout
     with self.cached_session():
       testing_utils.layer_test(
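Removing the eager-mode skip is the user-visible payoff in this file: test_conv_lstm_dropout now runs under eager execution as well, b/126246383 presumably being the mask-reuse bug this change fixes.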
@@ -1107,8 +1107,136 @@ class AbstractRNNCell(Layer):
     return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)


+class DropoutRNNCellMixin(object):
+  """Object that holds dropout-related fields for an RNN cell.
+
+  This class is not a standalone RNN cell. It is meant to be used with an RNN
+  cell via multiple inheritance. Any cell that mixes in this class should have
+  the following fields:
+    dropout: a float within the range [0, 1). The fraction of the input
+      tensor to drop out.
+    recurrent_dropout: a float within the range [0, 1). The fraction of the
+      recurrent state weights to drop out.
+  This object creates and caches dropout masks, and reuses them for incoming
+  data, so that the same mask is applied to every step within the same batch.
+  """
+
+  def __init__(self, *args, **kwargs):
+    # Note that the following two masks will be used in "graph function" mode,
+    # i.e. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
+    # tensors will be generated differently than in the "graph function" case,
+    # and they will be cached.
+    # Also note that in graph mode, we still cache those masks only because the
+    # RNN could be created with `unroll=True`. In that case, the `cell.call()`
+    # function will be invoked multiple times, and we want to ensure the same
+    # mask is used every time.
+    self._dropout_mask = None
+    self._recurrent_dropout_mask = None
+    self._eager_dropout_mask = None
+    self._eager_recurrent_dropout_mask = None
+    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
+
+  def reset_dropout_mask(self):
+    """Reset the cached dropout masks, if any.
+
+    It is important for the RNN layer to invoke this in its call() method so
+    that the cached mask is cleared before calling cell.call(). The mask
+    should be cached across all timesteps within the same batch, but shouldn't
+    be cached between batches. Otherwise it would introduce unreasonable bias
+    against certain indexes of data within the batch.
+    """
+    self._dropout_mask = None
+    self._eager_dropout_mask = None
+
+  def reset_recurrent_dropout_mask(self):
+    """Reset the cached recurrent dropout masks, if any.
+
+    It is important for the RNN layer to invoke this in its call() method so
+    that the cached mask is cleared before calling cell.call(). The mask
+    should be cached across all timesteps within the same batch, but shouldn't
+    be cached between batches. Otherwise it would introduce unreasonable bias
+    against certain indexes of data within the batch.
+    """
+    self._recurrent_dropout_mask = None
+    self._eager_recurrent_dropout_mask = None
+
+  def get_dropout_mask_for_cell(self, inputs, training, count=1):
+    """Get the dropout mask for the RNN cell's input.
+
+    It creates a mask based on context if there is no existing cached mask.
+    If a new mask is generated, it updates the cache in the cell.
+
+    Args:
+      inputs: the input tensor whose shape will be used to generate the
+        dropout mask.
+      training: boolean tensor, whether in training mode; dropout will be
+        ignored in non-training mode.
+      count: int, how many dropout masks will be generated. Useful for cells
+        that have internal weights fused together.
+    Returns:
+      List of mask tensors, generated or cached masks based on context.
+    """
+    if self.dropout == 0:
+      return None
+    if (not context.executing_eagerly() and self._dropout_mask is None
+        or context.executing_eagerly() and self._eager_dropout_mask is None):
+      # Generate new mask and cache it based on context.
+      dp_mask = _generate_dropout_mask(
+          array_ops.ones_like(inputs),
+          self.dropout,
+          training=training,
+          count=count)
+      if context.executing_eagerly():
+        self._eager_dropout_mask = dp_mask
+      else:
+        self._dropout_mask = dp_mask
+    else:
+      # Reuse the existing mask.
+      dp_mask = (self._eager_dropout_mask
+                 if context.executing_eagerly() else self._dropout_mask)
+    return dp_mask
+
+  def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
+    """Get the recurrent dropout mask for the RNN cell.
+
+    It creates a mask based on context if there is no existing cached mask.
+    If a new mask is generated, it updates the cache in the cell.
+
+    Args:
+      inputs: the input tensor whose shape will be used to generate the
+        dropout mask.
+      training: boolean tensor, whether in training mode; dropout will be
+        ignored in non-training mode.
+      count: int, how many dropout masks will be generated. Useful for cells
+        that have internal weights fused together.
+    Returns:
+      List of mask tensors, generated or cached masks based on context.
+    """
+    if self.recurrent_dropout == 0:
+      return None
+    if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
+        or context.executing_eagerly()
+        and self._eager_recurrent_dropout_mask is None):
+      # Generate new mask and cache it based on context.
+      rec_dp_mask = _generate_dropout_mask(
+          array_ops.ones_like(inputs),
+          self.recurrent_dropout,
+          training=training,
+          count=count)
+      if context.executing_eagerly():
+        self._eager_recurrent_dropout_mask = rec_dp_mask
+      else:
+        self._recurrent_dropout_mask = rec_dp_mask
+    else:
+      # Reuse the existing mask.
+      rec_dp_mask = (self._eager_recurrent_dropout_mask
+                     if context.executing_eagerly()
+                     else self._recurrent_dropout_mask)
+    return rec_dp_mask
+
+
 @keras_export('keras.layers.SimpleRNNCell')
-class SimpleRNNCell(Layer):
+class SimpleRNNCell(DropoutRNNCellMixin, Layer):
   """Cell class for SimpleRNN.

   Arguments:
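To see the contract from the docstring in practice, here is a hypothetical minimal cell built on the mixin (the MyCell name and its toy arithmetic are illustrative only, not part of this commit; Layer, K, and the mask machinery come from this module):

class MyCell(DropoutRNNCellMixin, Layer):
  """Hypothetical sketch of a cell using the mixin contract."""

  def __init__(self, units, dropout=0., **kwargs):
    super(MyCell, self).__init__(**kwargs)
    self.units = units
    self.dropout = dropout        # field the mixin expects, in [0, 1)
    self.recurrent_dropout = 0.   # field the mixin expects, in [0, 1)
    self.state_size = units
    self.output_size = units

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  name='kernel')
    self.built = True

  def call(self, inputs, states, training=None):
    # The first request in a forward pass generates and caches the mask
    # (per graph/eager context); every later timestep gets the same cached
    # mask back, so one mask covers the whole sequence.
    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
    if dp_mask is not None:
      inputs = inputs * dp_mask
    output = K.dot(inputs, self.kernel) + states[0]
    return output, [output]

The wrapping layer is expected to clear the caches once per pass, which is exactly what the SimpleRNN, GRU, and LSTM call() hunks below add via cell.reset_dropout_mask() and cell.reset_recurrent_dropout_mask().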
@@ -1185,8 +1313,6 @@ class SimpleRNNCell(Layer):
     self.recurrent_dropout = min(1., max(0., recurrent_dropout))
     self.state_size = self.units
     self.output_size = self.units
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None

   @tf_utils.shape_type_conversion
   def build(self, input_shape):
@@ -1215,20 +1341,9 @@ class SimpleRNNCell(Layer):

   def call(self, inputs, states, training=None):
     prev_output = states[0]
-    if 0 < self.dropout < 1 and self._dropout_mask is None:
-      self._dropout_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.dropout,
-          training=training)
-    if (0 < self.recurrent_dropout < 1 and
-        self._recurrent_dropout_mask is None):
-      self._recurrent_dropout_mask = _generate_dropout_mask(
-          array_ops.ones_like(prev_output),
-          self.recurrent_dropout,
-          training=training)
-
-    dp_mask = self._dropout_mask
-    rec_dp_mask = self._recurrent_dropout_mask
+    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
+    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
+        prev_output, training)

     if dp_mask is not None:
       h = K.dot(inputs * dp_mask, self.kernel)
@@ -1401,8 +1516,8 @@ class SimpleRNN(RNN):
     self.input_spec = [InputSpec(ndim=3)]

   def call(self, inputs, mask=None, training=None, initial_state=None):
-    self.cell._dropout_mask = None
-    self.cell._recurrent_dropout_mask = None
+    self.cell.reset_dropout_mask()
+    self.cell.reset_recurrent_dropout_mask()
     return super(SimpleRNN, self).call(
         inputs, mask=mask, training=training, initial_state=initial_state)

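The same two-line substitution repeats for GRU and LSTM below. The division of labor: the layer clears the caches exactly once per forward pass, and the cell lazily refills them at its first timestep, so every later timestep of that pass (unrolled graph or eager loop) sees the identical cached mask, while the next batch starts clean. It also stops the layers from reaching into the cell's private _dropout_mask attributes.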
@@ -1507,7 +1622,7 @@


 @keras_export('keras.layers.GRUCell')
-class GRUCell(Layer):
+class GRUCell(DropoutRNNCellMixin, Layer):
   """Cell class for the GRU layer.

   Arguments:
@@ -1604,8 +1719,6 @@ class GRUCell(Layer):
     self.reset_after = reset_after
     self.state_size = self.units
     self.output_size = self.units
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None

   @tf_utils.shape_type_conversion
   def build(self, input_shape):
@@ -1644,24 +1757,9 @@ class GRUCell(Layer):
   def call(self, inputs, states, training=None):
     h_tm1 = states[0]  # previous memory

-    if 0 < self.dropout < 1 and self._dropout_mask is None:
-      self._dropout_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.dropout,
-          training=training,
-          count=3)
-    if (0 < self.recurrent_dropout < 1 and
-        self._recurrent_dropout_mask is None):
-      self._recurrent_dropout_mask = _generate_dropout_mask(
-          array_ops.ones_like(h_tm1),
-          self.recurrent_dropout,
-          training=training,
-          count=3)
-
-    # dropout matrices for input units
-    dp_mask = self._dropout_mask
-    # dropout matrices for recurrent units
-    rec_dp_mask = self._recurrent_dropout_mask
+    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
+    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
+        h_tm1, training, count=3)

     if self.use_bias:
       if not self.reset_after:
@@ -1938,8 +2036,8 @@ class GRU(RNN):
     self.input_spec = [InputSpec(ndim=3)]

   def call(self, inputs, mask=None, training=None, initial_state=None):
-    self.cell._dropout_mask = None
-    self.cell._recurrent_dropout_mask = None
+    self.cell.reset_dropout_mask()
+    self.cell.reset_recurrent_dropout_mask()
     return super(GRU, self).call(
         inputs, mask=mask, training=training, initial_state=initial_state)

@@ -2062,7 +2160,7 @@


 @keras_export('keras.layers.GRU', v1=[])
-class UnifiedGRU(GRU):
+class UnifiedGRU(DropoutRNNCellMixin, GRU):
   """Gated Recurrent Unit - Cho et al. 2014.

   Based on available runtime hardware and constraints, this layer
@@ -2222,7 +2320,6 @@ class UnifiedGRU(GRU):
         time_major=time_major,
         reset_after=reset_after,
         **kwargs)
-    self._dropout_mask = None
     # CuDNN uses the following settings by default; they are not configurable.
     self.could_use_cudnn = (
         activation == 'tanh' and recurrent_activation == 'sigmoid' and
@@ -2242,8 +2339,6 @@ class UnifiedGRU(GRU):
     if mask is not None or not self.could_use_cudnn:
       # CuDNN does not support masking; fall back to the normal GRU.
       kwargs = {'training': training}
-      self.cell._dropout_mask = None
-      self.cell._recurrent_dropout_mask = None

       def step(cell_inputs, cell_states):
         return self.cell.call(cell_inputs, cell_states, **kwargs)
@@ -2288,15 +2383,11 @@ class UnifiedGRU(GRU):
     if self.go_backwards:
       # Reverse time axis.
       inputs = K.reverse(inputs, 0 if self.time_major else 1)
-    if 0 < self.dropout < 1:
-      if self._dropout_mask is None:
-        self._dropout_mask = _generate_dropout_mask(
-            array_ops.ones_like(inputs),
-            self.dropout,
-            training=training,
-            count=3)
-
-      inputs *= self._dropout_mask[0]
+    self.reset_dropout_mask()
+    dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
+    if dropout_mask is not None:
+      inputs *= dropout_mask[0]
+
     if ops.executing_eagerly_outside_functions():
       # Under eager context, the device placement is already known. Prefer the
       # GPU implementation when GPU is available.
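Note the asymmetry in the cuDNN-capable layers: UnifiedGRU here (and UnifiedLSTM below) mixes DropoutRNNCellMixin into the layer itself, because on the cuDNN fast path there is no per-timestep cell.call(). Input dropout is applied once to the whole sequence before the fused kernel, with reset_dropout_mask() invoked first so each batch draws a fresh mask. Recurrent dropout cannot be expressed on that path at all, which is why could_use_cudnn requires recurrent_dropout == 0 (visible in the UnifiedLSTM hunk below).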
@@ -2457,7 +2548,7 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, time_major):


 @keras_export('keras.layers.LSTMCell')
-class LSTMCell(Layer):
+class LSTMCell(DropoutRNNCellMixin, Layer):
   """Cell class for the LSTM layer.

   Arguments:
@@ -2557,8 +2648,6 @@ class LSTMCell(Layer):
     self.implementation = implementation
     self.state_size = [self.units, self.units]
     self.output_size = self.units
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None

   @tf_utils.shape_type_conversion
   def build(self, input_shape):
@@ -2621,28 +2710,13 @@ class LSTMCell(Layer):
     return c, o

   def call(self, inputs, states, training=None):
-    if 0 < self.dropout < 1 and self._dropout_mask is None:
-      self._dropout_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.dropout,
-          training=training,
-          count=4)
-    if (0 < self.recurrent_dropout < 1 and
-        self._recurrent_dropout_mask is None):
-      self._recurrent_dropout_mask = _generate_dropout_mask(
-          array_ops.ones_like(states[0]),
-          self.recurrent_dropout,
-          training=training,
-          count=4)
-
-    # dropout matrices for input units
-    dp_mask = self._dropout_mask
-    # dropout matrices for recurrent units
-    rec_dp_mask = self._recurrent_dropout_mask
-
     h_tm1 = states[0]  # previous memory state
     c_tm1 = states[1]  # previous carry state

+    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
+    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
+        h_tm1, training, count=4)
+
     if self.implementation == 1:
       if 0 < self.dropout < 1.:
         inputs_i = inputs * dp_mask[0]
@@ -2967,8 +3041,8 @@ class LSTM(RNN):
     self.input_spec = [InputSpec(ndim=3)]

   def call(self, inputs, mask=None, training=None, initial_state=None):
-    self.cell._dropout_mask = None
-    self.cell._recurrent_dropout_mask = None
+    self.cell.reset_dropout_mask()
+    self.cell.reset_recurrent_dropout_mask()
     return super(LSTM, self).call(
         inputs, mask=mask, training=training, initial_state=initial_state)

@@ -3091,7 +3165,7 @@


 @keras_export('keras.layers.LSTM', v1=[])
-class UnifiedLSTM(LSTM):
+class UnifiedLSTM(DropoutRNNCellMixin, LSTM):
   """Long Short-Term Memory layer - Hochreiter 1997.

   Based on available runtime hardware and constraints, this layer
@@ -3234,7 +3308,6 @@ class UnifiedLSTM(LSTM):
     self.state_spec = [
         InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
     ]
-    self._dropout_mask = None
     self.could_use_cudnn = (
         activation == 'tanh' and recurrent_activation == 'sigmoid' and
         recurrent_dropout == 0 and not unroll and use_bias)
@@ -3278,15 +3351,10 @@ class UnifiedLSTM(LSTM):
       # Reverse time axis.
       inputs = K.reverse(inputs, 0 if self.time_major else 1)

-    if 0 < self.dropout < 1:
-      if self._dropout_mask is None:
-        self._dropout_mask = _generate_dropout_mask(
-            array_ops.ones_like(inputs),
-            self.dropout,
-            training=training,
-            count=4)
-
-      inputs *= self._dropout_mask[0]
+    self.reset_dropout_mask()
+    dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
+    if dropout_mask is not None:
+      inputs *= dropout_mask[0]

     if ops.executing_eagerly_outside_functions():
       # Under eager context, the device placement is already known. Prefer the
@@ -23,13 +23,16 @@ from __future__ import print_function

 import collections

+from absl.testing import parameterized
 import numpy as np

 from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
@@ -725,14 +728,38 @@ class RNNTest(keras_parameterized.TestCase):
     y_np_2 = model.predict(x_np)
     self.assertAllClose(y_np, y_np_2, atol=1e-4)

-  def DISABLED_test_stacked_rnn_dropout(self):
-    # Temporarily disabled test due an occasional Grappler segfault.
-    # See b/115523414
-    cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
-             keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
-    layer = keras.layers.RNN(cells)
+  @parameterized.named_parameters(
+      *test_util.generate_combinations_with_testcase_name(
+          layer=[keras.layers.SimpleRNN, keras.layers.GRU, keras.layers.LSTM,
+                 keras.layers.UnifiedGRU, keras.layers.UnifiedLSTM],
+          unroll=[True, False]))
+  def test_rnn_dropout(self, layer, unroll):
+    rnn_layer = layer(3, dropout=0.1, recurrent_dropout=0.1, unroll=unroll)
+    if not unroll:
+      x = keras.Input((None, 5))
+    else:
+      x = keras.Input((5, 5))
+    y = rnn_layer(x)
+    model = keras.models.Model(x, y)
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    x_np = np.random.random((6, 5, 5))
+    y_np = np.random.random((6, 3))
+    model.train_on_batch(x_np, y_np)

-    x = keras.Input((None, 5))
+  @parameterized.named_parameters(
+      *test_util.generate_combinations_with_testcase_name(
+          cell=[keras.layers.SimpleRNNCell, keras.layers.GRUCell,
+                keras.layers.LSTMCell],
+          unroll=[True, False]))
+  def test_stacked_rnn_dropout(self, cell, unroll):
+    cells = [cell(3, dropout=0.1, recurrent_dropout=0.1),
+             cell(3, dropout=0.1, recurrent_dropout=0.1)]
+    layer = keras.layers.RNN(cells, unroll=unroll)
+
+    if not unroll:
+      x = keras.Input((None, 5))
+    else:
+      x = keras.Input((5, 5))
     y = layer(x)
     model = keras.models.Model(x, y)
     model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
@@ -740,6 +767,38 @@ class RNNTest(keras_parameterized.TestCase):
     y_np = np.random.random((6, 3))
     model.train_on_batch(x_np, y_np)

+  def test_dropout_mask_reuse(self):
+    # The layer is created with recurrent_initializer='zeros', so that the
+    # recurrent state won't affect the output. By doing this, we can verify
+    # the output and see whether the same mask is applied for each timestep.
+    rnn = keras.layers.SimpleRNN(3,
+                                 dropout=0.5,
+                                 kernel_initializer='ones',
+                                 recurrent_initializer='zeros',
+                                 return_sequences=True,
+                                 unroll=True)
+
+    inputs = constant_op.constant(1.0, shape=(6, 2, 5))
+    out = rnn(inputs, training=True)
+    if not context.executing_eagerly():
+      self.evaluate(variables_lib.global_variables_initializer())
+    batch_1 = self.evaluate(out)
+    batch_1_t0, batch_1_t1 = batch_1[:, 0, :], batch_1[:, 1, :]
+    self.assertAllClose(batch_1_t0, batch_1_t1)
+
+    # This simulates the layer being called with multiple batches in eager mode.
+    if context.executing_eagerly():
+      out2 = rnn(inputs, training=True)
+    else:
+      out2 = out
+    batch_2 = self.evaluate(out2)
+    batch_2_t0, batch_2_t1 = batch_2[:, 0, :], batch_2[:, 1, :]
+    self.assertAllClose(batch_2_t0, batch_2_t1)
+
+    # Also validate that a different dropout mask is used between batches.
+    self.assertNotAllClose(batch_1_t0, batch_2_t0)
+    self.assertNotAllClose(batch_1_t1, batch_2_t1)
+
   def test_stacked_rnn_compute_output_shape(self):
     cells = [keras.layers.LSTMCell(3),
              keras.layers.LSTMCell(6)]
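Why the initializers make this test conclusive: with recurrent_initializer='zeros' the recurrent term vanishes, so with an all-ones kernel and constant inputs the output of each unit at step t is tanh(sum(x_t * mask)). Equal outputs at t0 and t1 therefore imply the identical mask was applied at both timesteps, while differing outputs across the two eager calls imply a fresh mask per call.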
@@ -654,6 +654,20 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
         run_eagerly=testing_utils.should_run_eagerly())
     model.fit(x, y, epochs=1, shuffle=False)

+  def test_dropout_LSTM(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+    testing_utils.layer_test(
+        keras.layers.UnifiedLSTM,
+        kwargs={
+            'units': units,
+            'dropout': 0.1,
+            'recurrent_dropout': 0.1
+        },
+        input_shape=(num_samples, timesteps, embedding_dim))
+

 class LSTMLayerGraphOnlyTest(test.TestCase):

@@ -763,24 +777,6 @@ class LSTMLayerGraphOnlyTest(test.TestCase):
       existing_loss = loss_value


-class LSTMLayerV1OnlyTest(test.TestCase, parameterized.TestCase):
-
-  @test_util.run_in_graph_and_eager_modes(config=_config)
-  def test_dropout_LSTM(self):
-    num_samples = 2
-    timesteps = 3
-    embedding_dim = 4
-    units = 2
-    testing_utils.layer_test(
-        keras.layers.UnifiedLSTM,
-        kwargs={
-            'units': units,
-            'dropout': 0.1,
-            'recurrent_dropout': 0.1
-        },
-        input_shape=(num_samples, timesteps, embedding_dim))
-
-
 class UnifiedLSTMPerformanceTest(test.Benchmark):

   def _measure_performance(self, test_config, model, x_train, y_train):
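The test_dropout_LSTM body deleted here is the same one added to UnifiedLSTMTest above; the test simply moves out of the removed graph-only LSTMLayerV1OnlyTest class so the dropout path now runs under the standard parameterized eager and graph combinations.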
@@ -2,6 +2,7 @@ path: "tensorflow.keras.experimental.PeepholeLSTMCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.PeepholeLSTMCell\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.LSTMCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -141,6 +142,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -173,6 +178,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -181,6 +190,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.GRUCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.GRUCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -140,6 +141,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -172,6 +177,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -180,6 +189,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.LSTMCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.LSTMCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -140,6 +141,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -172,6 +177,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -180,6 +189,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.SimpleRNNCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.SimpleRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -140,6 +141,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -172,6 +177,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -180,6 +189,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -2,6 +2,7 @@ path: "tensorflow.keras.experimental.PeepholeLSTMCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.PeepholeLSTMCell\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.LSTMCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -141,6 +142,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -173,6 +178,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -181,6 +190,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.GRUCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.GRUCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -140,6 +141,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -172,6 +177,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -180,6 +189,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.GRU"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.UnifiedGRU\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.GRU\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.RNN\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
@@ -214,6 +215,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -246,6 +251,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -254,6 +263,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "reset_states"
     argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.LSTMCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.LSTMCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -140,6 +141,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -172,6 +177,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -180,6 +189,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.LSTM"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.UnifiedLSTM\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.LSTM\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.RNN\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
@@ -214,6 +215,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -246,6 +251,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -254,6 +263,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "reset_states"
     argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.SimpleRNNCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.recurrent.SimpleRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.recurrent.DropoutRNNCellMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -140,6 +141,10 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -172,6 +177,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_recurrent_dropout_mask_for_cell"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
+  }
   member_method {
     name: "get_updates_for"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -180,6 +189,14 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "reset_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_recurrent_dropout_mask"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"