Fix RNN cell wrt deep copy.

The root cause is an internal cache field that is not picklable by Python.

Fixes https://github.com/tensorflow/tensorflow/issues/39978.

PiperOrigin-RevId: 314265192
Change-Id: I66e80dea5fb65ac9c2e567f2af74510cd64ce5fc
Authored by Scott Zhu on 2020-06-01 21:52:21 -07:00; committed by TensorFlower Gardener
parent 1a71abed71
commit 3ab25c3203
4 changed files with 60 additions and 9 deletions
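
For context, the failure this change addresses can be reproduced along these lines (a minimal sketch; the exact cell class and error text vary by report and TF version):

    import copy
    import tensorflow as tf

    cell = tf.keras.layers.LSTMCell(5)
    copied_cell = copy.deepcopy(cell)  # raised a TypeError before this change
    assert copied_cell.get_config() == cell.get_config()

After the change, the copy succeeds and is configured identically to the original.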


@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from absl.testing import parameterized
 import numpy as np

@@ -253,6 +255,12 @@ class GRULayerGenericTest(test.TestCase):
     l2 = layer_class.from_config(l1.get_config())
     assert l1.get_config() == l2.get_config()
 
+  def test_deep_copy_GRU(self):
+    cell = keras.layers.GRUCell(5)
+    copied_cell = copy.deepcopy(cell)
+    self.assertEqual(copied_cell.units, 5)
+    self.assertEqual(cell.get_config(), copied_cell.get_config())
+
   def test_regularizers_GRU(self):
     embedding_dim = 4
     layer_class = keras.layers.GRU
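
The test above covers the bare cell; the same property should also hold for an unbuilt recurrent layer, since the layer's cell carries the mask caches. An illustrative check, not part of this change:

    import copy
    import tensorflow as tf

    layer = tf.keras.layers.GRU(5, dropout=0.2)
    copied_layer = copy.deepcopy(layer)
    assert copied_layer.get_config() == layer.get_config()

A built or trained layer may hold additional state, so this is a sketch rather than a guarantee.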


@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from absl.testing import parameterized
 import numpy as np

@@ -182,6 +184,12 @@ class LSTMLayerTest(keras_parameterized.TestCase):
     l2 = layer_class.from_config(l1.get_config())
     assert l1.get_config() == l2.get_config()
 
+  def test_deep_copy_LSTM(self):
+    cell = keras.layers.LSTMCell(5)
+    copied_cell = copy.deepcopy(cell)
+    self.assertEqual(copied_cell.units, 5)
+    self.assertEqual(cell.get_config(), copied_cell.get_config())
+
   def test_specify_initial_state_keras_tensor(self):
     num_states = 2
     timesteps = 3
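
Beyond config equality, a quick sanity check (illustrative, not part of the diff) is that a deep-copied cell remains independently usable, e.g. wrapped in its own RNN layer:

    import copy
    import numpy as np
    import tensorflow as tf

    cell = tf.keras.layers.LSTMCell(5)
    copied_cell = copy.deepcopy(cell)

    # The copy can be driven on its own data; output shape is (batch, units).
    out = tf.keras.layers.RNN(copied_cell)(np.zeros((2, 3, 4), dtype=np.float32))
    assert out.shape == (2, 5)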


@@ -1096,18 +1096,30 @@ class DropoutRNNCellMixin(object):
   """
 
   def __init__(self, *args, **kwargs):
-    # Note that the following two masks will be used in "graph function" mode,
-    # e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
-    # tensors will be generated differently than in the "graph function" case,
-    # and they will be cached.
-    # Also note that in graph mode, we still cache those masks only because the
-    # RNN could be created with `unroll=True`. In that case, the `cell.call()`
-    # function will be invoked multiple times, and we want to ensure same mask
-    # is used every time.
+    self._create_non_trackable_mask_cache()
+    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
+
+  @trackable.no_automatic_dependency_tracking
+  def _create_non_trackable_mask_cache(self):
+    """Create the cache for dropout and recurrent dropout mask.
+
+    Note that the following two masks will be used in "graph function" mode,
+    e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
+    tensors will be generated differently than in the "graph function" case,
+    and they will be cached.
+
+    Also note that in graph mode, we still cache those masks only because the
+    RNN could be created with `unroll=True`. In that case, the `cell.call()`
+    function will be invoked multiple times, and we want to ensure same mask
+    is used every time.
+
+    Also the caches are created without tracking. Since they are not picklable
+    by python when deepcopy, we don't want layer._obj_reference_counts_dict
+    to track it by default.
+    """
     self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
     self._recurrent_dropout_mask_cache = K.ContextValueCache(
         self._create_recurrent_dropout_mask)
-    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
 
   def reset_dropout_mask(self):
     """Reset the cached dropout masks if any.
@@ -1187,6 +1199,21 @@ class DropoutRNNCellMixin(object):
     init_kwargs = dict(inputs=inputs, training=training, count=count)
     return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
 
+  def __getstate__(self):
+    # Used for deepcopy. The caching can't be pickled by python, since it will
+    # contain tensor and graph.
+    state = super(DropoutRNNCellMixin, self).__getstate__()
+    state.pop('_dropout_mask_cache', None)
+    state.pop('_recurrent_dropout_mask_cache', None)
+    return state
+
+  def __setstate__(self, state):
+    state['_dropout_mask_cache'] = K.ContextValueCache(
+        self._create_dropout_mask)
+    state['_recurrent_dropout_mask_cache'] = K.ContextValueCache(
+        self._create_recurrent_dropout_mask)
+    super(DropoutRNNCellMixin, self).__setstate__(state)
+
 
 @keras_export('keras.layers.SimpleRNNCell')
 class SimpleRNNCell(DropoutRNNCellMixin, Layer):
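
The recurrent.py change is an instance of the standard Python pattern for objects holding unpicklable members: drop them in __getstate__ and rebuild them in __setstate__, hooks that both pickle and copy.deepcopy route through. A self-contained sketch of the pattern, with illustrative names that are not TensorFlow APIs:

    import copy
    import threading

    class CachingThing(object):
      """Holds a lock-guarded cache; locks cannot be pickled or deep-copied."""

      def __init__(self):
        self._cache_lock = threading.Lock()
        self._cache = {}

      def __getstate__(self):
        # Exclude the unpicklable members from the copied/pickled state.
        state = self.__dict__.copy()
        state.pop('_cache_lock', None)
        state.pop('_cache', None)
        return state

      def __setstate__(self, state):
        # Rebuild fresh, empty members on the copy, mirroring
        # DropoutRNNCellMixin.__setstate__ above.
        state['_cache_lock'] = threading.Lock()
        state['_cache'] = {}
        self.__dict__.update(state)

    thing = CachingThing()
    copied = copy.deepcopy(thing)  # without the hooks this raises TypeError
    assert copied._cache == {}

Rebuilding an empty cache on the copy is safe here because the dropout masks are recomputed lazily on the next call.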


@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from absl.testing import parameterized
 import numpy as np

@@ -133,6 +135,12 @@ class SimpleRNNLayerTest(test.TestCase, parameterized.TestCase):
     l2 = layer_class.from_config(l1.get_config())
     assert l1.get_config() == l2.get_config()
 
+  def test_deep_copy_SimpleRNN(self):
+    cell = keras.layers.SimpleRNNCell(5)
+    copied_cell = copy.deepcopy(cell)
+    self.assertEqual(copied_cell.units, 5)
+    self.assertEqual(cell.get_config(), copied_cell.get_config())
+
   def test_regularizers_SimpleRNN(self):
     embedding_dim = 4
     layer_class = keras.layers.SimpleRNN