Removing the wrapper._input_map field to avoid a potential memory leak.

Fix https://github.com/tensorflow/tensorflow/issues/33178.

PiperOrigin-RevId: 292043221
Change-Id: Ife2fa9a2adf50424bb2c932044fe3db5f4bb42d5
Scott Zhu 2020-01-28 16:55:50 -08:00 committed by TensorFlower Gardener
parent c3f78d1c1c
commit d064c6fc9a
2 changed files with 6 additions and 13 deletions
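
For context, here is a minimal, hypothetical sketch of the kind of leak a never-cleared per-layer cache can cause. LeakyWrapper and its fields are invented for this illustration and are not the real Keras classes; the real field used generic_utils.object_list_uid as the key rather than id().

import numpy as np


class LeakyWrapper(object):
  """Caches a reshaped copy of each input, keyed by the input's identity."""

  def __init__(self):
    # Entries are added on every call and never evicted, so the wrapper keeps
    # a reference to one reshaped array per distinct input it has processed.
    self._input_map = {}

  def call(self, inputs):
    reshaped = np.reshape(inputs, (-1, inputs.shape[-1]))
    self._input_map[id(inputs)] = reshaped
    return reshaped


wrapper = LeakyWrapper()
batches = [np.random.random((8, 3, 2)) for _ in range(100)]
for batch in batches:
  wrapper.call(batch)
print(len(wrapper._input_map))  # 100 cached arrays are still held by the wrapper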


@@ -51,10 +51,6 @@ class Wrapper(Layer):
   def __init__(self, layer, **kwargs):
     assert isinstance(layer, Layer)
     self.layer = layer
-    # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
-    # the inner layer has update ops that depend on its inputs (as opposed
-    # to the inputs to the Wrapper layer).
-    self._input_map = {}
     super(Wrapper, self).__init__(**kwargs)
 
   def build(self, input_shape=None):
@@ -258,9 +254,7 @@ class TimeDistributed(Wrapper):
         inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
         # Shape: (num_samples * timesteps, ...). And track the
         # transformation in self._input_map.
-        input_uid = generic_utils.object_list_uid(inputs)
         inputs = array_ops.reshape(inputs, inner_input_shape)
-        self._input_map[input_uid] = inputs
         # (num_samples * timesteps, ...)
         if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
           inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
@@ -314,15 +308,17 @@ class TimeDistributed(Wrapper):
     # cases need to call the layer.compute_mask when input_mask is None:
     # Masking layer and Embedding layer with mask_zero
     input_shape = K.int_shape(inputs)
-    if input_shape[0]:
-      # batch size matters, we currently do not handle mask explicitly
+    if input_shape[0] and not self._always_use_reshape or isinstance(
+        inputs, ragged_tensor.RaggedTensor):
+      # batch size matters, we currently do not handle mask explicitly, or if
+      # the layer always uses reshape approach, or the input is a ragged tensor.
       return mask
     inner_mask = mask
     if inner_mask is not None:
       inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
       inner_mask = K.reshape(inner_mask, inner_mask_shape)
-    input_uid = generic_utils.object_list_uid(inputs)
-    inner_inputs = self._input_map.get(input_uid, inputs)
+    inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
+    inner_inputs = array_ops.reshape(inputs, inner_input_shape)
     output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
     if output_mask is None:
       if mask is None:
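
A rough, self-contained sketch of the recompute-on-demand strategy used in the hunk above: instead of looking up cached reshaped inputs in self._input_map, compute_mask can rebuild the flattened (num_samples * timesteps, ...) view of the inputs when it needs it. The helper name inner_inputs_for_mask is invented for this sketch, and the shape computation only approximates what _get_shape_tuple does.

import tensorflow as tf


def inner_inputs_for_mask(inputs):
  # Collapse (batch, timesteps, ...) into (batch * timesteps, ...), mirroring
  # the reshape that TimeDistributed.call applies before invoking the inner
  # layer, so no per-input cache is needed.
  inner_shape = tf.concat([[-1], tf.shape(inputs)[2:]], axis=0)
  return tf.reshape(inputs, inner_shape)


x = tf.random.uniform((4, 10, 3))
print(inner_inputs_for_mask(x).shape)  # (40, 3)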


@@ -186,7 +186,6 @@ class TimeDistributedTest(keras_parameterized.TestCase):
     y = model.predict(np.random.random((10, 3, 2)))
     self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1)
 
-  @tf_test_util.run_v1_only(reason='b/148248386')
   def test_TimeDistributed_batchnorm(self):
     with self.cached_session():
       # test that wrapped BN updates still work.
@@ -206,8 +205,6 @@ class TimeDistributedTest(keras_parameterized.TestCase):
       # Assert that mean and variance changed.
       assert not np.array_equal(td.get_weights()[2], np.array([0, 0]))
       assert not np.array_equal(td.get_weights()[3], np.array([1, 1]))
-      # Verify input_map has one mapping from inputs to reshaped inputs.
-      self.assertEqual(len(td._input_map.keys()), 1)
 
   def test_TimeDistributed_trainable(self):
     # test layers that need learning_phase to be set
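
For reference, a small end-to-end example in the spirit of the batchnorm test touched above; the data distribution, number of epochs, and tolerances are illustrative rather than taken from the test file. The moving statistics of a BatchNormalization layer wrapped in TimeDistributed should still move away from their initial values during training.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.TimeDistributed(
        tf.keras.layers.BatchNormalization(center=True, scale=True),
        input_shape=(10, 2)),
])
model.compile(optimizer='rmsprop', loss='mse')
model.fit(np.random.normal(loc=2., scale=2., size=(8, 10, 2)),
          np.random.normal(loc=0., scale=1., size=(8, 10, 2)),
          epochs=4, verbose=0)

# Weights of the wrapped layer: [gamma, beta, moving_mean, moving_variance].
moving_mean, moving_variance = model.layers[0].get_weights()[2:4]
assert not np.allclose(moving_mean, np.zeros(2))     # moved away from 0
assert not np.allclose(moving_variance, np.ones(2))  # moved away from 1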