Fix deep copy of RNN cells.

The root cause is an internal cache field that is not picklable by Python.

Fixes https://github.com/tensorflow/tensorflow/issues/39978.

PiperOrigin-RevId: 314265192
Change-Id: I66e80dea5fb65ac9c2e567f2af74510cd64ce5fc
parent 1a71abed71
commit 3ab25c3203
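For context, a minimal repro sketch of the failure mode this commit addresses (assumes TensorFlow 2.x Keras; the exact error text varies by version):

import copy
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(5)
# Before this fix, deepcopy could fail because the cell's internal
# dropout-mask caches hold tensor/graph objects that Python cannot pickle.
copied_cell = copy.deepcopy(cell)
assert copied_cell.units == 5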
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from absl.testing import parameterized
 import numpy as np
 
@@ -253,6 +255,12 @@ class GRULayerGenericTest(test.TestCase):
     l2 = layer_class.from_config(l1.get_config())
     assert l1.get_config() == l2.get_config()
 
+  def test_deep_copy_GRU(self):
+    cell = keras.layers.GRUCell(5)
+    copied_cell = copy.deepcopy(cell)
+    self.assertEqual(copied_cell.units, 5)
+    self.assertEqual(cell.get_config(), copied_cell.get_config())
+
   def test_regularizers_GRU(self):
     embedding_dim = 4
     layer_class = keras.layers.GRU
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from absl.testing import parameterized
 import numpy as np
 
@@ -182,6 +184,12 @@ class LSTMLayerTest(keras_parameterized.TestCase):
     l2 = layer_class.from_config(l1.get_config())
     assert l1.get_config() == l2.get_config()
 
+  def test_deep_copy_LSTM(self):
+    cell = keras.layers.LSTMCell(5)
+    copied_cell = copy.deepcopy(cell)
+    self.assertEqual(copied_cell.units, 5)
+    self.assertEqual(cell.get_config(), copied_cell.get_config())
+
   def test_specify_initial_state_keras_tensor(self):
     num_states = 2
     timesteps = 3
@@ -1096,18 +1096,30 @@ class DropoutRNNCellMixin(object):
   """
 
   def __init__(self, *args, **kwargs):
-    # Note that the following two masks will be used in "graph function" mode,
-    # e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
-    # tensors will be generated differently than in the "graph function" case,
-    # and they will be cached.
-    # Also note that in graph mode, we still cache those masks only because the
-    # RNN could be created with `unroll=True`. In that case, the `cell.call()`
-    # function will be invoked multiple times, and we want to ensure same mask
-    # is used every time.
+    self._create_non_trackable_mask_cache()
+    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
+
+  @trackable.no_automatic_dependency_tracking
+  def _create_non_trackable_mask_cache(self):
+    """Create the cache for dropout and recurrent dropout mask.
+
+    Note that the following two masks will be used in "graph function" mode,
+    e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
+    tensors will be generated differently than in the "graph function" case,
+    and they will be cached.
+
+    Also note that in graph mode, we still cache those masks only because the
+    RNN could be created with `unroll=True`. In that case, the `cell.call()`
+    function will be invoked multiple times, and we want to ensure same mask
+    is used every time.
+
+    Also the caches are created without tracking. Since they are not picklable
+    by python when deepcopy, we don't want layer._obj_reference_counts_dict
+    to track it by default.
+    """
     self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
     self._recurrent_dropout_mask_cache = K.ContextValueCache(
         self._create_recurrent_dropout_mask)
-    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
 
   def reset_dropout_mask(self):
     """Reset the cached dropout masks if any.
@@ -1187,6 +1199,21 @@ class DropoutRNNCellMixin(object):
     init_kwargs = dict(inputs=inputs, training=training, count=count)
     return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
 
+  def __getstate__(self):
+    # Used for deepcopy. The caching can't be pickled by python, since it will
+    # contain tensor and graph.
+    state = super(DropoutRNNCellMixin, self).__getstate__()
+    state.pop('_dropout_mask_cache', None)
+    state.pop('_recurrent_dropout_mask_cache', None)
+    return state
+
+  def __setstate__(self, state):
+    state['_dropout_mask_cache'] = K.ContextValueCache(
+        self._create_dropout_mask)
+    state['_recurrent_dropout_mask_cache'] = K.ContextValueCache(
+        self._create_recurrent_dropout_mask)
+    super(DropoutRNNCellMixin, self).__setstate__(state)
+
 
 @keras_export('keras.layers.SimpleRNNCell')
 class SimpleRNNCell(DropoutRNNCellMixin, Layer):
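The change above follows the standard Python pattern of dropping an unpicklable attribute in __getstate__ and rebuilding it in __setstate__. A minimal, standalone sketch of that pattern (illustrative only; CachedThing and _build_cache are hypothetical names, not TensorFlow code):

import copy


class CachedThing(object):

  def __init__(self):
    # Stand-in for the K.ContextValueCache fields on DropoutRNNCellMixin;
    # in the real layer the cache can hold tensors and graphs.
    self._cache = self._build_cache()

  def _build_cache(self):
    return {}

  def __getstate__(self):
    state = self.__dict__.copy()
    state.pop('_cache', None)  # never try to pickle the cache
    return state

  def __setstate__(self, state):
    state['_cache'] = self._build_cache()  # rebuild a fresh, empty cache
    self.__dict__.update(state)


thing = CachedThing()
thing._cache['key'] = 'value'
copied = copy.deepcopy(thing)
assert copied._cache == {}  # the cache is reset on the copy, not duplicated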
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
+
 from absl.testing import parameterized
 import numpy as np
 
@@ -133,6 +135,12 @@ class SimpleRNNLayerTest(test.TestCase, parameterized.TestCase):
     l2 = layer_class.from_config(l1.get_config())
     assert l1.get_config() == l2.get_config()
 
+  def test_deep_copy_SimpleRNN(self):
+    cell = keras.layers.SimpleRNNCell(5)
+    copied_cell = copy.deepcopy(cell)
+    self.assertEqual(copied_cell.units, 5)
+    self.assertEqual(cell.get_config(), copied_cell.get_config())
+
   def test_regularizers_SimpleRNN(self):
     embedding_dim = 4
     layer_class = keras.layers.SimpleRNN