Removing wrapper._input_map field to avoid potential memory leak.
Fix https://github.com/tensorflow/tensorflow/issues/33178. PiperOrigin-RevId: 292043221 Change-Id: Ife2fa9a2adf50424bb2c932044fe3db5f4bb42d5
This commit is contained in:
parent
c3f78d1c1c
commit
d064c6fc9a
@ -51,10 +51,6 @@ class Wrapper(Layer):
|
||||
def __init__(self, layer, **kwargs):
|
||||
assert isinstance(layer, Layer)
|
||||
self.layer = layer
|
||||
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
|
||||
# the inner layer has update ops that depend on its inputs (as opposed
|
||||
# to the inputs to the Wrapper layer).
|
||||
self._input_map = {}
|
||||
super(Wrapper, self).__init__(**kwargs)
|
||||
|
||||
def build(self, input_shape=None):
|
||||
@ -258,9 +254,7 @@ class TimeDistributed(Wrapper):
|
||||
inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
|
||||
# Shape: (num_samples * timesteps, ...). And track the
|
||||
# transformation in self._input_map.
|
||||
input_uid = generic_utils.object_list_uid(inputs)
|
||||
inputs = array_ops.reshape(inputs, inner_input_shape)
|
||||
self._input_map[input_uid] = inputs
|
||||
# (num_samples * timesteps, ...)
|
||||
if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
|
||||
inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
|
||||
@ -314,15 +308,17 @@ class TimeDistributed(Wrapper):
|
||||
# cases need to call the layer.compute_mask when input_mask is None:
|
||||
# Masking layer and Embedding layer with mask_zero
|
||||
input_shape = K.int_shape(inputs)
|
||||
if input_shape[0]:
|
||||
# batch size matters, we currently do not handle mask explicitly
|
||||
if input_shape[0] and not self._always_use_reshape or isinstance(
|
||||
inputs, ragged_tensor.RaggedTensor):
|
||||
# batch size matters, we currently do not handle mask explicitly, or if
|
||||
# the layer always uses reshape approach, or the input is a ragged tensor.
|
||||
return mask
|
||||
inner_mask = mask
|
||||
if inner_mask is not None:
|
||||
inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
|
||||
inner_mask = K.reshape(inner_mask, inner_mask_shape)
|
||||
input_uid = generic_utils.object_list_uid(inputs)
|
||||
inner_inputs = self._input_map.get(input_uid, inputs)
|
||||
inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
|
||||
inner_inputs = array_ops.reshape(inputs, inner_input_shape)
|
||||
output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
|
||||
if output_mask is None:
|
||||
if mask is None:
|
||||
|
@ -186,7 +186,6 @@ class TimeDistributedTest(keras_parameterized.TestCase):
|
||||
y = model.predict(np.random.random((10, 3, 2)))
|
||||
self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1)
|
||||
|
||||
@tf_test_util.run_v1_only(reason='b/148248386')
|
||||
def test_TimeDistributed_batchnorm(self):
|
||||
with self.cached_session():
|
||||
# test that wrapped BN updates still work.
|
||||
@ -206,8 +205,6 @@ class TimeDistributedTest(keras_parameterized.TestCase):
|
||||
# Assert that mean and variance changed.
|
||||
assert not np.array_equal(td.get_weights()[2], np.array([0, 0]))
|
||||
assert not np.array_equal(td.get_weights()[3], np.array([1, 1]))
|
||||
# Verify input_map has one mapping from inputs to reshaped inputs.
|
||||
self.assertEqual(len(td._input_map.keys()), 1)
|
||||
|
||||
def test_TimeDistributed_trainable(self):
|
||||
# test layers that need learning_phase to be set
|
||||
|
Loading…
x
Reference in New Issue
Block a user