diff --git a/tensorflow/python/keras/preprocessing/BUILD b/tensorflow/python/keras/preprocessing/BUILD index 640e47a1d44..7c75e45fc58 100644 --- a/tensorflow/python/keras/preprocessing/BUILD +++ b/tensorflow/python/keras/preprocessing/BUILD @@ -19,6 +19,7 @@ py_library( ":image", ":sequence", ":text", + ":timeseries", ], ) @@ -47,6 +48,19 @@ py_library( ], ) +py_library( + name = "timeseries", + srcs = [ + "timeseries.py", + ], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + py_library( name = "text", srcs = [ @@ -104,3 +118,16 @@ tf_py_test( "//third_party/py/numpy", ], ) + +tf_py_test( + name = "timeseries_test", + size = "small", + srcs = ["timeseries_test.py"], + python_version = "PY3", + deps = [ + ":timeseries", + "//tensorflow/python:client_testlib", + "//tensorflow/python/compat:v2_compat", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/keras/preprocessing/timeseries.py b/tensorflow/python/keras/preprocessing/timeseries.py new file mode 100644 index 00000000000..ca41f1952e3 --- /dev/null +++ b/tensorflow/python/keras/preprocessing/timeseries.py @@ -0,0 +1,181 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras timeseries dataset utilities.""" +# pylint: disable=g-classes-have-attributes +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def timeseries_dataset( + data, + targets, + sequence_length, + sampling_rate=1, + sequence_stride=1, + batch_size=128, + shuffle=False, + seed=None, + start_index=None, + end_index=None): + """Utility function for generating batches of temporal data. + + This function takes in a sequence of data-points gathered at + equal intervals, along with time series parameters such as + spacing between two sequence, length of history, etc., to produce batches for + training/validation. + + Arguments: + data: Indexable generator (such as a list or a Numpy array) + containing consecutive data points (timesteps). + Axis 0 is expected to be the time dimension. + targets: Targets corresponding to timesteps in `data`. + It should have same length as `data`. + Pass None if you don't have target data (in this case the dataset will + only yield the input data). + sequence_length: Length of the output sequences (in number of timesteps). + sampling_rate: Period between successive individual timesteps + within sequences. For rate `r`, timesteps + `data[i], data[i + r], ... data[i + sequence_length]` + are used for create a sample sequence. + sequence_stride: Period between successive output sequences. + For stride `s`, output samples would + start at index `data[i]`, `data[i + s]`, `data[i + 2 * s]`, etc. + batch_size: Number of timeseries samples in each batch + (except maybe the last one). + shuffle: Whether to shuffle output samples, + or instead draw them in chronological order. + seed: Optional int; random seed for shuffling. + start_index: Optional int; data points earlier (exclusive) + than `start_index` will not be used + in the output sequences. This is useful to reserve part of the + data for test or validation. + end_index: Optional int; data points later (exclusive) than `end_index` + will not be used in the output sequences. + This is useful to reserve part of the data for test or validation. + + Returns: + A tf.data.Dataset instance. If `targets` was pass, the dataset yields + tuple `(batch_of_sequences, batch_of_targets)`. If not, the dataset yields + only `batch_of_sequences`. + + Example: + Consider indices `[0, 1, ... 99]`. + With `sequence_length=10, sampling_rate=2, sequence_stride=3`, + `shuffle=False`, the dataset will yield batches of sequences + composed of the following indices: + + ``` + First sequence: [0 2 4 6 8 10 12 14 16 18] + Second sequence: [3 5 7 9 11 13 15 17 19 21] + Third sequence: [6 8 10 12 14 16 18 20 22 24] + ... + Last sequence: [78 80 82 84 86 88 90 92 94 96] + ``` + + In this case the last 3 data points are discarded since no full sequence + can be generated to include them (the next sequence would have started + at index 81, and thus its last step would have gone over 99). + """ + # Validate the shape of data and targets + if targets is not None and len(targets) != len(data): + raise ValueError('Expected data and targets to have the same number of ' + 'time steps (axis 0) but got ' + 'shape(data) = %s; shape(targets) = %s.' % + (data.shape, targets.shape)) + if start_index and (start_index < 0 or start_index >= len(data)): + raise ValueError('start_index must be higher than 0 and lower than the ' + 'length of the data. Got: start_index=%s ' + 'for data of length %s.' % (start_index, len(data))) + if end_index: + if start_index and end_index <= start_index: + raise ValueError('end_index must be higher than start_index. Got: ' + 'start_index=%s, end_index=%s.' % + (start_index, end_index)) + if end_index >= len(data): + raise ValueError('end_index must be lower than the length of the data. ' + 'Got: end_index=%s' % (end_index,)) + if end_index <= 0: + raise ValueError('end_index must be higher than 0. ' + 'Got: end_index=%s' % (end_index,)) + + # Validate strides + if sampling_rate <= 0 or sampling_rate >= len(data): + raise ValueError( + 'sampling_rate must be higher than 0 and lower than ' + 'the length of the data. Got: ' + 'sampling_rate=%s for data of length %s.' % (sampling_rate, len(data))) + if sequence_stride <= 0 or sequence_stride >= len(data): + raise ValueError( + 'sequence_stride must be higher than 0 and lower than ' + 'the length of the data. Got: sequence_stride=%s ' + 'for data of length %s.' % (sequence_stride, len(data))) + + if start_index is None: + start_index = 0 + if end_index is None: + end_index = len(data) + + # Determine the lowest dtype to store start positions (to lower memory usage). + num_seqs = end_index - start_index - (sequence_length * sampling_rate) + 1 + if num_seqs < 2147483647: + index_dtype = 'int32' + else: + index_dtype = 'int64' + + # Generate start positions + start_positions = np.arange(0, num_seqs, sequence_stride, dtype=index_dtype) + if shuffle: + if seed is None: + seed = np.random.randint(1e6) + rng = np.random.RandomState(seed) + rng.shuffle(start_positions) + + sequence_length = math_ops.cast(sequence_length, dtype=index_dtype) + sampling_rate = math_ops.cast(sampling_rate, dtype=index_dtype) + + # For each initial window position, generates indices of the window elements + indices = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(len(start_positions)), + dataset_ops.Dataset.from_tensors(start_positions).repeat())).map( + lambda i, positions: math_ops.range( # pylint: disable=g-long-lambda + positions[i], + positions[i] + sequence_length * sampling_rate, + sampling_rate), + num_parallel_calls=dataset_ops.AUTOTUNE) + + dataset = sequences_from_indices(data, indices, start_index, end_index) + if targets is not None: + target_ds = sequences_from_indices(targets, indices, start_index, end_index) + dataset = dataset_ops.Dataset.zip((dataset, target_ds)) + if shuffle: + # Shuffle locally at each iteration + dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed) + dataset = dataset.batch(batch_size) + return dataset + + +def sequences_from_indices(array, indices_ds, start_index, end_index): + dataset = dataset_ops.Dataset.from_tensors(array[start_index : end_index]) + dataset = dataset_ops.Dataset.zip((dataset.repeat(), indices_ds)).map( + lambda steps, inds: array_ops.gather(steps, inds), # pylint: disable=unnecessary-lambda + num_parallel_calls=dataset_ops.AUTOTUNE) + return dataset diff --git a/tensorflow/python/keras/preprocessing/timeseries_test.py b/tensorflow/python/keras/preprocessing/timeseries_test.py new file mode 100644 index 00000000000..ab1640191bf --- /dev/null +++ b/tensorflow/python/keras/preprocessing/timeseries_test.py @@ -0,0 +1,162 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for timeseries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.compat import v2_compat +from tensorflow.python.keras.preprocessing import timeseries +from tensorflow.python.platform import test + + +class TimeseriesDatasetTest(test.TestCase): + + def test_basics(self): + # Test ordering, targets, sequence length, batch size + data = np.arange(100) + targets = data * 2 + dataset = timeseries.timeseries_dataset( + data, targets, sequence_length=9, batch_size=5) + # Expect 19 batches + for i, batch in enumerate(dataset): + self.assertLen(batch, 2) + if i < 18: + self.assertEqual(batch[0].shape, (5, 9)) + if i == 18: + # Last batch: size 2 + self.assertEqual(batch[0].shape, (2, 9)) + # Check target values + self.assertAllClose(batch[0] * 2, batch[1]) + for j in range(min(5, len(batch[0]))): + # Check each sample in the batch + self.assertAllClose(batch[0][j], np.arange(i * 5 + j, i * 5 + j + 9)) + + def test_no_targets(self): + data = np.arange(50) + dataset = timeseries.timeseries_dataset( + data, None, sequence_length=10, batch_size=5) + # Expect 9 batches + for i, batch in enumerate(dataset): + if i < 8: + self.assertEqual(batch.shape, (5, 10)) + elif i == 8: + self.assertEqual(batch.shape, (1, 10)) + for j in range(min(5, len(batch))): + # Check each sample in the batch + self.assertAllClose(batch[j], np.arange(i * 5 + j, i * 5 + j + 10)) + + def test_shuffle(self): + # Test cross-epoch random order and seed determinism + data = np.arange(10) + targets = data * 2 + dataset = timeseries.timeseries_dataset( + data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123) + first_seq = None + for x, y in dataset.take(1): + self.assertNotAllClose(x, np.arange(0, 5)) + self.assertAllClose(x * 2, y) + first_seq = x + # Check that a new iteration with the same dataset yields different results + for x, _ in dataset.take(1): + self.assertNotAllClose(x, first_seq) + # Check determism with same seed + dataset = timeseries.timeseries_dataset( + data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123) + for x, _ in dataset.take(1): + self.assertAllClose(x, first_seq) + + def test_sampling_rate(self): + data = np.arange(100) + targets = data * 2 + dataset = timeseries.timeseries_dataset( + data, targets, sequence_length=9, batch_size=5, sampling_rate=2) + for i, batch in enumerate(dataset): + self.assertLen(batch, 2) + if i < 16: + self.assertEqual(batch[0].shape, (5, 9)) + if i == 16: + # Last batch: size 3 + self.assertEqual(batch[0].shape, (3, 9)) + # Check target values + self.assertAllClose(batch[0] * 2, batch[1]) + for j in range(min(5, len(batch[0]))): + # Check each sample in the batch + start_index = i * 5 + j + end_index = start_index + 9 * 2 + self.assertAllClose(batch[0][j], + np.arange(start_index, end_index, 2)) + + def test_sequence_stride(self): + data = np.arange(100) + targets = data * 2 + dataset = timeseries.timeseries_dataset( + data, targets, sequence_length=9, batch_size=5, sequence_stride=3) + for i, batch in enumerate(dataset): + self.assertLen(batch, 2) + if i < 6: + self.assertEqual(batch[0].shape, (5, 9)) + if i == 6: + # Last batch: size 1 + self.assertEqual(batch[0].shape, (1, 9)) + # Check target values + self.assertAllClose(batch[0] * 2, batch[1]) + for j in range(min(5, len(batch[0]))): + # Check each sample in the batch + start_index = i * 5 * 3 + j * 3 + end_index = start_index + 9 + self.assertAllClose(batch[0][j], + np.arange(start_index, end_index)) + + def test_start_and_end_index(self): + data = np.arange(100) + dataset = timeseries.timeseries_dataset( + data, None, + sequence_length=9, batch_size=5, sequence_stride=3, sampling_rate=2, + start_index=10, end_index=90) + for batch in dataset: + self.assertAllLess(batch[0], 90) + self.assertAllGreater(batch[0], 9) + + def test_errors(self): + # bad targets + with self.assertRaisesRegex(ValueError, + 'data and targets to have the same number'): + _ = timeseries.timeseries_dataset(np.arange(10), np.arange(9), 3) + # bad start index + with self.assertRaisesRegex(ValueError, 'start_index must be '): + _ = timeseries.timeseries_dataset(np.arange(10), None, 3, start_index=-1) + with self.assertRaisesRegex(ValueError, 'start_index must be '): + _ = timeseries.timeseries_dataset(np.arange(10), None, 3, start_index=11) + # bad end index + with self.assertRaisesRegex(ValueError, 'end_index must be '): + _ = timeseries.timeseries_dataset(np.arange(10), None, 3, end_index=-1) + with self.assertRaisesRegex(ValueError, 'end_index must be '): + _ = timeseries.timeseries_dataset(np.arange(10), None, 3, end_index=11) + # bad sampling_rate + with self.assertRaisesRegex(ValueError, 'sampling_rate must be '): + _ = timeseries.timeseries_dataset(np.arange(10), None, 3, sampling_rate=0) + # bad sequence stride + with self.assertRaisesRegex(ValueError, 'sequence_stride must be '): + _ = timeseries.timeseries_dataset( + np.arange(10), None, 3, sequence_stride=0) + + +if __name__ == '__main__': + v2_compat.enable_v2_behavior() + test.main()