Trying to reduce the memory leak when validation_split is used in model.fit().

Not sure what life cycle the eager tensor has, but the memory increase is greatly reduced after this change when testing with np input and validation_split.
It is possible that there are still issues in the code base that cause eager tensors to not be released, but we probably shouldn't aggressively convert np or pd data into eager tensors.
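
For reference, a minimal sketch of how per-run numbers like the logs below can be produced (assumed setup using psutil; not the actual benchmark behind these logs):

import numpy as np
import psutil
import tensorflow as tf

def run_once(x, y):
  # Each run builds and fits a fresh model; validation_split exercises
  # data_adapter.train_validation_split() on the NumPy inputs.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
  model.compile(optimizer='sgd', loss='mse')
  model.fit(x, y, epochs=1, validation_split=0.2, verbose=0)

x = np.random.random((1000, 10))
y = np.random.random((1000, 1))
process = psutil.Process()
for run in range(1, 21):
  run_once(x, y)
  print('#--- Run %d of 20 memory used (MB): %s'
        % (run, process.memory_info().rss / 1e6))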

Memory log before change:
#--- Run 1 of 20 memory used (MB): 420.94592
#--- Run 2 of 20 memory used (MB): 455.458816
#--- Run 3 of 20 memory used (MB): 480.89088
#--- Run 4 of 20 memory used (MB): 504.799232
#--- Run 5 of 20 memory used (MB): 465.563648
#--- Run 6 of 20 memory used (MB): 485.797888
#--- Run 7 of 20 memory used (MB): 506.544128
#--- Run 8 of 20 memory used (MB): 526.76608
#--- Run 9 of 20 memory used (MB): 547.782656
#--- Run 10 of 20 memory used (MB): 487.981056
#--- Run 11 of 20 memory used (MB): 508.862464
#--- Run 12 of 20 memory used (MB): 528.904192
#--- Run 13 of 20 memory used (MB): 549.933056
#--- Run 14 of 20 memory used (MB): 570.032128
#--- Run 15 of 20 memory used (MB): 510.455808
#--- Run 16 of 20 memory used (MB): 530.501632
#--- Run 17 of 20 memory used (MB): 551.559168
#--- Run 18 of 20 memory used (MB): 571.408384
#--- Run 19 of 20 memory used (MB): 529.518592
#--- Run 20 of 20 memory used (MB): 549.376

Memory log after change (memory now plateaus around 475 MB instead of growing ~20 MB per run):
#--- Run 1 of 20 memory used (MB): 441.933824
#--- Run 2 of 20 memory used (MB): 463.753216
#--- Run 3 of 20 memory used (MB): 465.801216
#--- Run 4 of 20 memory used (MB): 466.366464
#--- Run 5 of 20 memory used (MB): 467.0464
#--- Run 6 of 20 memory used (MB): 467.709952
#--- Run 7 of 20 memory used (MB): 468.668416
#--- Run 8 of 20 memory used (MB): 468.62336
#--- Run 9 of 20 memory used (MB): 474.35776
#--- Run 10 of 20 memory used (MB): 474.353664
#--- Run 11 of 20 memory used (MB): 474.472448
#--- Run 12 of 20 memory used (MB): 474.648576
#--- Run 13 of 20 memory used (MB): 474.697728
#--- Run 14 of 20 memory used (MB): 474.750976
#--- Run 15 of 20 memory used (MB): 474.804224
#--- Run 16 of 20 memory used (MB): 474.800128
#--- Run 17 of 20 memory used (MB): 474.857472
#--- Run 18 of 20 memory used (MB): 474.918912
#--- Run 19 of 20 memory used (MB): 475.086848
#--- Run 20 of 20 memory used (MB): 475.348992

PiperOrigin-RevId: 314746357
Change-Id: I84cd784059ae4aec827a6e908df2ca738b2dac48
Scott Zhu 2020-06-04 09:44:23 -07:00, committed by TensorFlower Gardener
parent fe33f393b8
commit ce2f9824ee
3 changed files with 17 additions and 61 deletions

tensorflow/python/keras/engine/data_adapter.py

@@ -1365,8 +1365,10 @@ def expand_1d(data):
   return nest.map_structure(_expand_single_1d_tensor, data)
 
 
-def train_validation_split(arrays, validation_split, shuffle=True):
-  """Split arrays into random train and validation subsets.
+def train_validation_split(arrays, validation_split):
+  """Split arrays into train and validation subsets in deterministic order.
+
+  The last part of data will become validation data.
 
   Arguments:
     arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
@@ -1374,10 +1376,6 @@ def train_validation_split(arrays, validation_split, shuffle=True):
     validation_split: Float between 0 and 1. The proportion of the dataset to
       include in the validation split. The rest of the dataset will be included
       in the training split.
-    shuffle: Bool. Whether to shuffle the data before performing a split. If
-      `False`, the last `validation_split` fraction of that training data will
-      become the validation split.
 
   Returns:
     `(train_arrays, validation_arrays)`
   """
@@ -1406,12 +1404,7 @@ def train_validation_split(arrays, validation_split, shuffle=True):
   # Assumes all arrays have the same batch shape or are `None`.
   batch_dim = int(first_non_none.shape[0])
-  indices = ops.convert_to_tensor_v2(range(batch_dim))
-  if shuffle:
-    indices = random_ops.random_shuffle(indices)
   split_at = int(math.floor(batch_dim * (1. - validation_split)))
-  train_indices = indices[:split_at]
-  val_indices = indices[split_at:]
 
   if split_at == 0 or split_at == batch_dim:
     raise ValueError(
@@ -1421,16 +1414,15 @@ def train_validation_split(arrays, validation_split, shuffle=True):
         "different value for the `validation_split` argument." .format(
             batch_dim=batch_dim, validation_split=validation_split))
 
-  def _split(t, indices):
+  def _split(t, start, end):
     if t is None:
       return t
-    t = ops.convert_to_tensor_v2(t)
-    return array_ops.gather_v2(t, indices)
+    return t[start:end]
 
   train_arrays = nest.map_structure(
-      functools.partial(_split, indices=train_indices), arrays)
+      functools.partial(_split, start=0, end=split_at), arrays)
   val_arrays = nest.map_structure(
-      functools.partial(_split, indices=val_indices), arrays)
+      functools.partial(_split, start=split_at, end=batch_dim), arrays)
 
   return train_arrays, val_arrays
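
The heart of the change is `_split` above: plain slicing instead of convert-then-gather. A small illustration of the difference, using the public `tf.convert_to_tensor` and `tf.gather` as stand-ins for the internal `ops.convert_to_tensor_v2` and `array_ops.gather_v2`:

import numpy as np
import tensorflow as tf

x = np.arange(10)

# New path: slicing keeps NumPy data as NumPy (a view, not a copy),
# so no EagerTensor is created for the runtime to keep alive.
train_new = x[:8]
print(type(train_new))   # <class 'numpy.ndarray'>

# Old path: convert-then-gather materializes EagerTensor copies up front.
indices = tf.convert_to_tensor(list(range(10)))
train_old = tf.gather(tf.convert_to_tensor(x), indices[:8])
print(type(train_old))   # EagerTensor

For Tensor inputs slicing is also cheaper, since `t[start:end]` avoids building an index tensor at all.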

tensorflow/python/keras/engine/data_adapter_test.py

@@ -985,7 +985,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
 class TestValidationSplit(keras_parameterized.TestCase):
 
   @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
-  def test_validation_split_shuffled(self, use_numpy):
+  def test_validation_split_unshuffled(self, use_numpy):
     if use_numpy:
       x = np.array([0, 1, 2, 3, 4])
       y = np.array([0, 2, 4, 6, 8])
@@ -998,48 +998,13 @@ class TestValidationSplit(keras_parameterized.TestCase):
     (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
         data_adapter.train_validation_split((x, y, sw), validation_split=0.2))
 
-    self.assertEqual(int(train_x.shape[0]), 4)
-    self.assertEqual(int(train_y.shape[0]), 4)
-    self.assertEqual(int(train_sw.shape[0]), 4)
-    for i in range(4):
-      # Check that all arrays were shuffled in identical order.
-      self.assertEqual(2 * train_x[i].numpy(), train_y[i].numpy())
-      self.assertEqual(2 * train_y[i].numpy(), train_sw[i].numpy())
-
-    self.assertEqual(int(val_x.shape[0]), 1)
-    self.assertEqual(int(val_y.shape[0]), 1)
-    self.assertEqual(int(val_sw.shape[0]), 1)
-    for i in range(1):
-      # Check that all arrays were shuffled in identical order.
-      self.assertEqual(2 * train_x[i].numpy(), train_y[i].numpy())
-      self.assertEqual(2 * train_y[i].numpy(), train_sw[i].numpy())
-
-    # Check that arrays contain expected values.
-    self.assertEqual(
-        sorted(array_ops.concat([train_x, val_x], axis=0).numpy().tolist()),
-        sorted(ops.convert_to_tensor_v2(x).numpy().tolist()))
-    self.assertEqual(
-        sorted(array_ops.concat([train_y, val_y], axis=0).numpy().tolist()),
-        sorted(ops.convert_to_tensor_v2(y).numpy().tolist()))
-    self.assertEqual(
-        sorted(array_ops.concat([train_sw, val_sw], axis=0).numpy().tolist()),
-        sorted(ops.convert_to_tensor_v2(sw).numpy().tolist()))
-
-  @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
-  def test_validation_split_unshuffled(self, use_numpy):
-    if use_numpy:
-      x = np.array([0, 1, 2, 3, 4])
-      y = np.array([0, 2, 4, 6, 8])
-      sw = np.array([0, 4, 8, 12, 16])
-    else:
-      x = ops.convert_to_tensor_v2([0, 1, 2, 3, 4])
-      y = ops.convert_to_tensor_v2([0, 2, 4, 6, 8])
-      sw = ops.convert_to_tensor_v2([0, 4, 8, 12, 16])
-
-    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
-        data_adapter.train_validation_split((x, y, sw),
-                                            validation_split=0.2,
-                                            shuffle=False))
+    train_x = ops.convert_to_tensor_v2(train_x)
+    train_y = ops.convert_to_tensor_v2(train_y)
+    train_sw = ops.convert_to_tensor_v2(train_sw)
+    val_x = ops.convert_to_tensor_v2(val_x)
+    val_y = ops.convert_to_tensor_v2(val_y)
+    val_sw = ops.convert_to_tensor_v2(val_sw)
 
     self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
     self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
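
For intuition, a standalone sketch of the deterministic split semantics the updated test asserts (illustrative names, not the Keras internals):

import math
import numpy as np

def split_last_fraction(arrays, validation_split):
  # Mirrors the new train_validation_split(): the last fraction of every
  # array becomes validation data; inputs are sliced, never converted.
  batch_dim = int(arrays[0].shape[0])
  split_at = int(math.floor(batch_dim * (1. - validation_split)))
  return (tuple(a[:split_at] for a in arrays),
          tuple(a[split_at:] for a in arrays))

(train_x, train_y), (val_x, val_y) = split_last_fraction(
    (np.array([0, 1, 2, 3, 4]), np.array([0, 2, 4, 6, 8])),
    validation_split=0.2)
print(train_x.tolist())  # [0, 1, 2, 3]
print(val_x.tolist())    # [4]

The updated test converts its outputs with `ops.convert_to_tensor_v2` before comparing precisely because NumPy inputs now come back as NumPy slices rather than eager tensors.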

tensorflow/python/keras/engine/training.py

@@ -1030,9 +1030,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
       # Create the validation data using the training data. Only supported for
       # `Tensor` and `NumPy` input.
       (x, y, sample_weight), validation_data = (
-          data_adapter.train_validation_split((x, y, sample_weight),
-                                              validation_split=validation_split,
-                                              shuffle=False))
+          data_adapter.train_validation_split(
+              (x, y, sample_weight), validation_split=validation_split))
 
       with self.distribute_strategy.scope(), \
           training_utils.RespectCompiledTrainableState(self):
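
From the caller's side nothing changes except that the split is now always deterministic; a minimal usage sketch:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x = np.random.random((100, 4))
y = np.random.random((100, 1))
# The last 20% of rows become the validation set, taken before any shuffling;
# fit()'s shuffle argument only reorders the remaining training portion.
model.fit(x, y, validation_split=0.2, epochs=1, verbose=0)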