Implement timeseries_dataset utility (as per prior RFC).
PiperOrigin-RevId: 300437253
Change-Id: If453ad0f23bde9826ca8942575bb77160fd6fb57
parent 73c6de788c
commit aedb53e371
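Below is a minimal sketch (not part of the commit) of how the new utility is meant to be called, based on the signature this diff introduces; the array values are placeholders.

```python
import numpy as np
from tensorflow.python.keras.preprocessing import timeseries

data = np.arange(100)   # timesteps along axis 0
targets = data * 2      # one target per timestep
# Yields (batch_of_sequences, batch_of_targets) batches.
dataset = timeseries.timeseries_dataset(
    data, targets, sequence_length=9, batch_size=5)
```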
@@ -19,6 +19,7 @@ py_library(
         ":image",
         ":sequence",
         ":text",
+        ":timeseries",
     ],
 )
@@ -47,6 +48,19 @@ py_library(
     ],
 )
 
+py_library(
+    name = "timeseries",
+    srcs = [
+        "timeseries.py",
+    ],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_library(
     name = "text",
     srcs = [
@@ -104,3 +118,16 @@ tf_py_test(
         "//third_party/py/numpy",
     ],
 )
+
+tf_py_test(
+    name = "timeseries_test",
+    size = "small",
+    srcs = ["timeseries_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":timeseries",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/compat:v2_compat",
+        "//third_party/py/numpy",
+    ],
+)
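With these BUILD additions in place, the new test target can presumably be run from a TensorFlow source checkout with `bazel test //tensorflow/python/keras/preprocessing:timeseries_test`.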
tensorflow/python/keras/preprocessing/timeseries.py (new file, 181 lines)
@@ -0,0 +1,181 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras timeseries dataset utilities."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

def timeseries_dataset(
    data,
    targets,
    sequence_length,
    sampling_rate=1,
    sequence_stride=1,
    batch_size=128,
    shuffle=False,
    seed=None,
    start_index=None,
    end_index=None):
"""Utility function for generating batches of temporal data.
|
||||
|
||||
This function takes in a sequence of data-points gathered at
|
||||
equal intervals, along with time series parameters such as
|
||||
spacing between two sequence, length of history, etc., to produce batches for
|
||||
training/validation.
|
||||
|
||||
Arguments:
|
||||
data: Indexable generator (such as a list or a Numpy array)
|
||||
containing consecutive data points (timesteps).
|
||||
Axis 0 is expected to be the time dimension.
|
||||
targets: Targets corresponding to timesteps in `data`.
|
||||
It should have same length as `data`.
|
||||
Pass None if you don't have target data (in this case the dataset will
|
||||
only yield the input data).
|
||||
sequence_length: Length of the output sequences (in number of timesteps).
|
||||
sampling_rate: Period between successive individual timesteps
|
||||
within sequences. For rate `r`, timesteps
|
||||
`data[i], data[i + r], ... data[i + sequence_length]`
|
||||
are used for create a sample sequence.
|
||||
sequence_stride: Period between successive output sequences.
|
||||
For stride `s`, output samples would
|
||||
start at index `data[i]`, `data[i + s]`, `data[i + 2 * s]`, etc.
|
||||
batch_size: Number of timeseries samples in each batch
|
||||
(except maybe the last one).
|
||||
shuffle: Whether to shuffle output samples,
|
||||
or instead draw them in chronological order.
|
||||
seed: Optional int; random seed for shuffling.
|
||||
start_index: Optional int; data points earlier (exclusive)
|
||||
than `start_index` will not be used
|
||||
in the output sequences. This is useful to reserve part of the
|
||||
data for test or validation.
|
||||
end_index: Optional int; data points later (exclusive) than `end_index`
|
||||
will not be used in the output sequences.
|
||||
This is useful to reserve part of the data for test or validation.
|
||||
|
||||
Returns:
|
||||
A tf.data.Dataset instance. If `targets` was pass, the dataset yields
|
||||
tuple `(batch_of_sequences, batch_of_targets)`. If not, the dataset yields
|
||||
only `batch_of_sequences`.
|
||||
|
||||
Example:
|
||||
Consider indices `[0, 1, ... 99]`.
|
||||
With `sequence_length=10, sampling_rate=2, sequence_stride=3`,
|
||||
`shuffle=False`, the dataset will yield batches of sequences
|
||||
composed of the following indices:
|
||||
|
||||
```
|
||||
First sequence: [0 2 4 6 8 10 12 14 16 18]
|
||||
Second sequence: [3 5 7 9 11 13 15 17 19 21]
|
||||
Third sequence: [6 8 10 12 14 16 18 20 22 24]
|
||||
...
|
||||
Last sequence: [78 80 82 84 86 88 90 92 94 96]
|
||||
```
|
||||
|
||||
In this case the last 3 data points are discarded since no full sequence
|
||||
can be generated to include them (the next sequence would have started
|
||||
at index 81, and thus its last step would have gone over 99).
|
||||
"""
|
||||
  # Validate the shapes of `data` and `targets`.
  if targets is not None and len(targets) != len(data):
    raise ValueError('Expected data and targets to have the same number of '
                     'time steps (axis 0) but got '
                     'shape(data) = %s; shape(targets) = %s.' %
                     (data.shape, targets.shape))
  if start_index and (start_index < 0 or start_index >= len(data)):
    raise ValueError('start_index must be higher than 0 and lower than the '
                     'length of the data. Got: start_index=%s '
                     'for data of length %s.' % (start_index, len(data)))
  if end_index:
    if start_index and end_index <= start_index:
      raise ValueError('end_index must be higher than start_index. Got: '
                       'start_index=%s, end_index=%s.' %
                       (start_index, end_index))
    if end_index >= len(data):
      raise ValueError('end_index must be lower than the length of the data. '
                       'Got: end_index=%s' % (end_index,))
    if end_index <= 0:
      raise ValueError('end_index must be higher than 0. '
                       'Got: end_index=%s' % (end_index,))

  # Validate strides.
  if sampling_rate <= 0 or sampling_rate >= len(data):
    raise ValueError(
        'sampling_rate must be higher than 0 and lower than '
        'the length of the data. Got: '
        'sampling_rate=%s for data of length %s.' % (sampling_rate, len(data)))
  if sequence_stride <= 0 or sequence_stride >= len(data):
    raise ValueError(
        'sequence_stride must be higher than 0 and lower than '
        'the length of the data. Got: sequence_stride=%s '
        'for data of length %s.' % (sequence_stride, len(data)))

  if start_index is None:
    start_index = 0
  if end_index is None:
    end_index = len(data)

  # Determine the lowest dtype to store start positions (to lower memory
  # usage).
  num_seqs = end_index - start_index - (sequence_length * sampling_rate) + 1
  if num_seqs < 2147483647:
    index_dtype = 'int32'
  else:
    index_dtype = 'int64'

  # Generate start positions for the output sequences, expressed as offsets
  # into `data[start_index:end_index]`.
  start_positions = np.arange(0, num_seqs, sequence_stride, dtype=index_dtype)
  if shuffle:
    if seed is None:
      seed = np.random.randint(1e6)
    rng = np.random.RandomState(seed)
    rng.shuffle(start_positions)

  sequence_length = math_ops.cast(sequence_length, dtype=index_dtype)
  sampling_rate = math_ops.cast(sampling_rate, dtype=index_dtype)

  # For each initial window position, generate indices of the window elements.
  indices = dataset_ops.Dataset.zip(
      (dataset_ops.Dataset.range(len(start_positions)),
       dataset_ops.Dataset.from_tensors(start_positions).repeat())).map(
           lambda i, positions: math_ops.range(  # pylint: disable=g-long-lambda
               positions[i],
               positions[i] + sequence_length * sampling_rate,
               sampling_rate),
           num_parallel_calls=dataset_ops.AUTOTUNE)

  dataset = sequences_from_indices(data, indices, start_index, end_index)
  if targets is not None:
    target_ds = sequences_from_indices(
        targets, indices, start_index, end_index)
    dataset = dataset_ops.Dataset.zip((dataset, target_ds))
  if shuffle:
    # Shuffle locally at each iteration.
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
  return dataset


def sequences_from_indices(array, indices_ds, start_index, end_index):
  # Materialize the (sliced) source array once, then gather the window
  # elements from it for every index tensor produced by `indices_ds`.
  dataset = dataset_ops.Dataset.from_tensors(array[start_index:end_index])
  dataset = dataset_ops.Dataset.zip((dataset.repeat(), indices_ds)).map(
      lambda steps, inds: array_ops.gather(steps, inds),  # pylint: disable=unnecessary-lambda
      num_parallel_calls=dataset_ops.AUTOTUNE)
  return dataset
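For reference, a quick sketch (not part of the diff) that exercises the gather-based windowing above and reproduces the docstring example; it assumes an eager (TF2) context, as the tests below enable:

```python
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.keras.preprocessing import timeseries

v2_compat.enable_v2_behavior()

data = np.arange(100)
dataset = timeseries.timeseries_dataset(
    data, None, sequence_length=10, sampling_rate=2, sequence_stride=3,
    batch_size=4)
for batch in dataset.take(1):
  print(batch.numpy()[0])  # => [ 0  2  4  6  8 10 12 14 16 18]
```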
tensorflow/python/keras/preprocessing/timeseries_test.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for timeseries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.keras.preprocessing import timeseries
from tensorflow.python.platform import test

class TimeseriesDatasetTest(test.TestCase):

  def test_basics(self):
    # Test ordering, targets, sequence length, batch size.
    data = np.arange(100)
    targets = data * 2
    dataset = timeseries.timeseries_dataset(
        data, targets, sequence_length=9, batch_size=5)
    # Expect 19 batches (92 sequences, 5 per batch).
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      if i < 18:
        self.assertEqual(batch[0].shape, (5, 9))
      if i == 18:
        # Last batch: size 2
        self.assertEqual(batch[0].shape, (2, 9))
      # Check target values
      self.assertAllClose(batch[0] * 2, batch[1])
      for j in range(min(5, len(batch[0]))):
        # Check each sample in the batch
        self.assertAllClose(batch[0][j], np.arange(i * 5 + j, i * 5 + j + 9))

  def test_no_targets(self):
    data = np.arange(50)
    dataset = timeseries.timeseries_dataset(
        data, None, sequence_length=10, batch_size=5)
    # Expect 9 batches (41 sequences, 5 per batch).
    for i, batch in enumerate(dataset):
      if i < 8:
        self.assertEqual(batch.shape, (5, 10))
      elif i == 8:
        self.assertEqual(batch.shape, (1, 10))
      for j in range(min(5, len(batch))):
        # Check each sample in the batch
        self.assertAllClose(batch[j], np.arange(i * 5 + j, i * 5 + j + 10))

  def test_shuffle(self):
    # Test cross-epoch random order and seed determinism.
    data = np.arange(10)
    targets = data * 2
    dataset = timeseries.timeseries_dataset(
        data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123)
    first_seq = None
    for x, y in dataset.take(1):
      self.assertNotAllClose(x, np.arange(0, 5))
      self.assertAllClose(x * 2, y)
      first_seq = x
    # Check that a new iteration over the same dataset yields different
    # results.
    for x, _ in dataset.take(1):
      self.assertNotAllClose(x, first_seq)
    # Check determinism with the same seed.
    dataset = timeseries.timeseries_dataset(
        data, targets, sequence_length=5, batch_size=1, shuffle=True, seed=123)
    for x, _ in dataset.take(1):
      self.assertAllClose(x, first_seq)

  def test_sampling_rate(self):
    data = np.arange(100)
    targets = data * 2
    dataset = timeseries.timeseries_dataset(
        data, targets, sequence_length=9, batch_size=5, sampling_rate=2)
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      if i < 16:
        self.assertEqual(batch[0].shape, (5, 9))
      if i == 16:
        # Last batch: size 3
        self.assertEqual(batch[0].shape, (3, 9))
      # Check target values
      self.assertAllClose(batch[0] * 2, batch[1])
      for j in range(min(5, len(batch[0]))):
        # Check each sample in the batch
        start_index = i * 5 + j
        end_index = start_index + 9 * 2
        self.assertAllClose(batch[0][j],
                            np.arange(start_index, end_index, 2))

  def test_sequence_stride(self):
    data = np.arange(100)
    targets = data * 2
    dataset = timeseries.timeseries_dataset(
        data, targets, sequence_length=9, batch_size=5, sequence_stride=3)
    for i, batch in enumerate(dataset):
      self.assertLen(batch, 2)
      if i < 6:
        self.assertEqual(batch[0].shape, (5, 9))
      if i == 6:
        # Last batch: size 1
        self.assertEqual(batch[0].shape, (1, 9))
      # Check target values
      self.assertAllClose(batch[0] * 2, batch[1])
      for j in range(min(5, len(batch[0]))):
        # Check each sample in the batch
        start_index = i * 5 * 3 + j * 3
        end_index = start_index + 9
        self.assertAllClose(batch[0][j],
                            np.arange(start_index, end_index))

  def test_start_and_end_index(self):
    data = np.arange(100)
    dataset = timeseries.timeseries_dataset(
        data, None,
        sequence_length=9, batch_size=5, sequence_stride=3, sampling_rate=2,
        start_index=10, end_index=90)
    for batch in dataset:
      self.assertAllLess(batch[0], 90)
      self.assertAllGreater(batch[0], 9)

  def test_errors(self):
    # Bad targets
    with self.assertRaisesRegex(ValueError,
                                'data and targets to have the same number'):
      _ = timeseries.timeseries_dataset(np.arange(10), np.arange(9), 3)
    # Bad start index
    with self.assertRaisesRegex(ValueError, 'start_index must be '):
      _ = timeseries.timeseries_dataset(np.arange(10), None, 3, start_index=-1)
    with self.assertRaisesRegex(ValueError, 'start_index must be '):
      _ = timeseries.timeseries_dataset(np.arange(10), None, 3, start_index=11)
    # Bad end index
    with self.assertRaisesRegex(ValueError, 'end_index must be '):
      _ = timeseries.timeseries_dataset(np.arange(10), None, 3, end_index=-1)
    with self.assertRaisesRegex(ValueError, 'end_index must be '):
      _ = timeseries.timeseries_dataset(np.arange(10), None, 3, end_index=11)
    # Bad sampling_rate
    with self.assertRaisesRegex(ValueError, 'sampling_rate must be '):
      _ = timeseries.timeseries_dataset(np.arange(10), None, 3, sampling_rate=0)
    # Bad sequence stride
    with self.assertRaisesRegex(ValueError, 'sequence_stride must be '):
      _ = timeseries.timeseries_dataset(
          np.arange(10), None, 3, sequence_stride=0)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()
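As the docstring notes, `start_index`/`end_index` are meant for reserving part of the data for validation. A small sketch of that pattern (an illustration, not part of the commit):

```python
import numpy as np
from tensorflow.python.keras.preprocessing import timeseries

data = np.arange(1000)
# Windows drawn only from data[:800].
train_ds = timeseries.timeseries_dataset(
    data, None, sequence_length=20, end_index=800)
# Windows drawn only from data[800:].
val_ds = timeseries.timeseries_dataset(
    data, None, sequence_length=20, start_index=800)
```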