Create text_dataset_from_directory utility and refactor shared code between text_dataset_from_directory and image_dataset_from_directory.

PiperOrigin-RevId: 307844332
Change-Id: I4b5f094dfa98a71d5ffa59c3134794bce765a575
Francois Chollet 2020-04-22 10:25:47 -07:00 committed by TensorFlower Gardener
parent c3759a4130
commit 09e67c47ee
8 changed files with 661 additions and 141 deletions

tensorflow/python/keras/preprocessing/BUILD

@@ -31,6 +31,7 @@ py_library(
py_library(
name = "image",
srcs = [
"dataset_utils.py",
"image.py",
"image_dataset.py",
],
@@ -69,7 +70,9 @@ py_library(
py_library(
name = "text",
srcs = [
"dataset_utils.py",
"text.py",
"text_dataset.py",
],
deps = ["//tensorflow/python:util"],
)
@@ -124,6 +127,19 @@ tf_py_test(
],
)
tf_py_test(
name = "text_dataset_test",
size = "small",
srcs = ["text_dataset_test.py"],
python_version = "PY3",
deps = [
":text",
"//tensorflow/python:client_testlib",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/keras",
],
)
tf_py_test(
name = "timeseries_test",
size = "small",

tensorflow/python/keras/preprocessing/dataset_utils.py

@@ -0,0 +1,201 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
def index_directory(directory,
labels,
formats,
class_names=None,
shuffle=True,
seed=None,
follow_links=False):
"""Make list of all files in the subdirs of `directory`, with their labels.
Args:
directory: The target directory (string).
labels: Either "inferred"
(labels are generated from the directory structure),
or a list/tuple of integer labels of the same size as the number of
valid files found in the directory. Labels should be sorted according
to the alphanumeric order of the file paths
(obtained via `os.walk(directory)` in Python).
formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
class_names: Only valid if "labels" is "inferred". This is the explicit
list of class names (must match names of subdirectories). Used
to control the order of the classes
(otherwise alphanumerical order is used).
shuffle: Whether to shuffle the data. Default: True.
If set to False, sorts the data in alphanumeric order.
seed: Optional random seed for shuffling.
follow_links: Whether to visit subdirectories pointed to by symlinks.
Returns:
tuple (file_paths, labels, class_names).
file_paths: list of file paths (strings).
labels: list of matching integer labels (same length as file_paths)
class_names: names of the classes corresponding to these labels, in order.
"""
inferred_class_names = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
inferred_class_names.append(subdir)
if not class_names:
class_names = inferred_class_names
else:
if set(class_names) != set(inferred_class_names):
raise ValueError(
'The `class_names` passed did not match the '
'names of the subdirectories of the target directory. '
'Expected: %s, but received: %s' %
(inferred_class_names, class_names))
class_indices = dict(zip(class_names, range(len(class_names))))
# Build an index of the files
# in the different class subfolders.
pool = multiprocessing.pool.ThreadPool()
results = []
filenames = []
for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
results.append(
pool.apply_async(index_subdirectory,
(dirpath, class_indices, follow_links, formats)))
labels_list = []
for res in results:
partial_filenames, partial_labels = res.get()
labels_list.append(partial_labels)
filenames += partial_filenames
if labels != 'inferred':
if len(labels) != len(filenames):
raise ValueError('Expected the lengths of `labels` to match the number '
'of files in the target directory. len(labels) is %s '
'while we found %s files in %s.' % (
len(labels), len(filenames), directory))
else:
i = 0
labels = np.zeros((len(filenames),), dtype='int32')
for partial_labels in labels_list:
labels[i:i + len(partial_labels)] = partial_labels
i += len(partial_labels)
print('Found %d files belonging to %d classes.' %
(len(filenames), len(class_names)))
pool.close()
pool.join()
file_paths = [os.path.join(directory, fname) for fname in filenames]
if shuffle:
# Shuffle globally to erase macro-structure
if seed is None:
seed = np.random.randint(1e6)
rng = np.random.RandomState(seed)
rng.shuffle(file_paths)
rng = np.random.RandomState(seed)
rng.shuffle(labels)
return file_paths, labels, class_names
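The same-seed double shuffle above keeps `file_paths` and `labels` aligned without materializing an index permutation. A minimal standalone sketch of the trick (toy data, not part of this commit):
```python
import numpy as np

paths = ['class_a/1.txt', 'class_a/2.txt', 'class_b/1.txt', 'class_b/2.txt']
labels = np.array([0, 0, 1, 1])

seed = 42
# Two RandomState instances built from the same seed draw identical random
# sequences, so each shuffle call applies the same permutation.
np.random.RandomState(seed).shuffle(paths)
np.random.RandomState(seed).shuffle(labels)
# paths and labels remain pairwise aligned here.
```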
def iter_valid_files(directory, follow_links, formats):
walk = os.walk(directory, followlinks=follow_links)
for root, _, files in sorted(walk, key=lambda x: x[0]):
for fname in sorted(files):
if fname.lower().endswith(formats):
yield root, fname
def index_subdirectory(directory, class_indices, follow_links, formats):
"""Recursively walks directory and list image paths and their class index.
Arguments:
directory: string, target directory.
class_indices: dict mapping class names to their index.
follow_links: boolean, whether to visit subdirectories pointed to by
symlinks (if False, symlinked subdirectories are not walked).
formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
Returns:
tuple `(filenames, labels)`. `filenames` is a list of relative file
paths, and `labels` is a list of integer labels corresponding to these
files.
"""
dirname = os.path.basename(directory)
valid_files = iter_valid_files(directory, follow_links, formats)
labels = []
filenames = []
for root, fname in valid_files:
labels.append(class_indices[dirname])
absolute_path = os.path.join(root, fname)
relative_path = os.path.join(
dirname, os.path.relpath(absolute_path, directory))
filenames.append(relative_path)
return filenames, labels
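The relative paths recorded here keep the class folder name as a prefix, so `index_directory` can later join them back onto the top-level directory. A toy illustration with hypothetical paths:
```python
import os

directory = '/data/class_a'          # hypothetical class subfolder
root = '/data/class_a/subfolder_1'   # as yielded by os.walk
fname = 'sample.txt'

dirname = os.path.basename(directory)      # 'class_a'
absolute_path = os.path.join(root, fname)
relative_path = os.path.join(
    dirname, os.path.relpath(absolute_path, directory))
print(relative_path)  # class_a/subfolder_1/sample.txt (on POSIX)
```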
def get_training_or_validation_split(samples, labels, validation_split, subset):
"""Potentially restict samples & labels to a training or validation split.
Args:
samples: List of elements.
labels: List of corresponding labels.
validation_split: Float, fraction of data to reserve for validation.
subset: Subset of the data to return.
Either "training", "validation", or None. If None, we return all of the
data.
Returns:
tuple (samples, labels), potentially restricted to the specified subset.
"""
if validation_split:
if not 0 < validation_split < 1:
raise ValueError(
'`validation_split` must be between 0 and 1, received: %s' %
(validation_split,))
if subset is None:
return samples, labels
num_val_samples = int(validation_split * len(samples))
if subset == 'training':
samples = samples[:-num_val_samples]
labels = labels[:-num_val_samples]
elif subset == 'validation':
samples = samples[-num_val_samples:]
labels = labels[-num_val_samples:]
else:
raise ValueError('`subset` must be either "training" '
'or "validation", received: %s' % (subset,))
return samples, labels
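To make the slicing concrete: with 10 samples and `validation_split=0.2`, `num_val_samples` is 2, so "training" keeps the first 8 entries and "validation" the last 2 (toy values; the validation-split test added later in this commit exercises exactly these sizes):
```python
samples = list(range(10))
num_val_samples = int(0.2 * len(samples))  # 2
assert samples[:-num_val_samples] == [0, 1, 2, 3, 4, 5, 6, 7]  # 'training'
assert samples[-num_val_samples:] == [8, 9]                    # 'validation'
```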
def labels_to_dataset(labels, label_mode, num_classes):
label_ds = dataset_ops.Dataset.from_tensor_slices(labels)
if label_mode == 'binary':
label_ds = label_ds.map(
lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1))
elif label_mode == 'categorical':
label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
return label_ds
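What each `label_mode` yields per element, sketched with the public `tf.data` / `tf.one_hot` equivalents of the internal ops used above:
```python
import tensorflow as tf

labels = [0, 1, 1]
int_ds = tf.data.Dataset.from_tensor_slices(labels)  # 'int': scalar int labels
# 'binary': float32 values of shape (1,)
binary_ds = int_ds.map(
    lambda x: tf.expand_dims(tf.cast(x, 'float32'), axis=-1))
# 'categorical': one-hot float32 vectors of shape (num_classes,)
categorical_ds = int_ds.map(lambda x: tf.one_hot(x, 2))
```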

tensorflow/python/keras/preprocessing/image_dataset.py

@@ -18,17 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.ops import array_ops
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
@@ -49,7 +45,7 @@ def image_dataset_from_directory(directory,
subset=None,
interpolation='bilinear',
follow_links=False):
"""Generates a Dataset from image files in a directory.
"""Generates a `tf.data.Dataset` from image files in a directory.
If your directory structure is:
@@ -63,10 +59,10 @@ def image_dataset_from_directory(directory,
......b_image_2.jpg
```
Then calling `from_directory(main_directory, labels='inferred')`
will return a Dataset that yields batches of images from
Then calling `image_dataset_from_directory(main_directory, labels='inferred')`
will return a `tf.data.Dataset` that yields batches of images from
the subdirectories `class_a` and `class_b`, together with labels
0 and 1 (0 corresponding to class_a and 1 corresponding to class_b).
0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).
Supported image formats: jpeg, png, bmp, gif.
Animated gifs are truncated to the first frame.
@@ -126,22 +122,22 @@ def image_dataset_from_directory(directory,
has shape `(batch_size, image_size[0], image_size[1], num_channels)`,
and `labels` follows the format described below.
Rules regarding labels format:
- if `label_mode` is `int`, the labels are an `int32` tensor of shape
`(batch_size,)`.
- if `label_mode` is `binary`, the labels are a `float32` tensor of
1s and 0s of shape `(batch_size, 1)`.
- if `label_mode` is `categorical`, the labels are a `float32` tensor
of shape `(batch_size, num_classes)`, representing a one-hot
encoding of the class index.
Rules regarding number of channels in the yielded images:
- if `color_mode` is `grayscale`,
there's 1 channel in the image tensors.
- if `color_mode` is `rgb`,
there are 3 channels in the image tensors.
- if `color_mode` is `rgba`,
there are 4 channels in the image tensors.
"""
if labels != 'inferred':
if not isinstance(labels, (list, tuple)):
@@ -172,85 +168,25 @@ def image_dataset_from_directory(directory,
'Received: %s' % (color_mode,))
interpolation = image_preprocessing.get_interpolation(interpolation)
inferred_class_names = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
inferred_class_names.append(subdir)
if not class_names:
class_names = inferred_class_names
else:
if set(class_names) != set(inferred_class_names):
raise ValueError(
'The `class_names` passed did not match the '
'names of the subdirectories of the target directory. '
'Expected: %s, but received: %s' %
(inferred_class_names, class_names))
class_indices = dict(zip(class_names, range(len(class_names))))
if seed is None:
seed = np.random.randint(1e6)
image_paths, labels, class_names = dataset_utils.index_directory(
directory,
labels,
formats=WHITELIST_FORMATS,
class_names=class_names,
shuffle=shuffle,
seed=seed,
follow_links=follow_links)
if label_mode == 'binary' and len(class_names) != 2:
raise ValueError(
'When passing `label_mode="binary"`, there must be exactly 2 classes. '
'Found the following classes: %s' % (class_names,))
# Build an index of the images
# in the different class subfolders.
pool = multiprocessing.pool.ThreadPool()
results = []
filenames = []
for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
results.append(
pool.apply_async(list_labeled_images_in_directory,
(dirpath, class_indices, follow_links)))
labels_list = []
for res in results:
partial_labels, partial_filenames = res.get()
labels_list.append(partial_labels)
filenames += partial_filenames
if labels != 'inferred':
if len(labels) != len(filenames):
raise ValueError('Expected the lengths of `labels` to match the number '
'of images in the target directory. len(labels) is %s '
'while we found %s images in %s.' % (
len(labels), len(filenames), directory))
else:
i = 0
labels = np.zeros((len(filenames),), dtype='int32')
for partial_labels in labels_list:
labels[i:i + len(partial_labels)] = partial_labels
i += len(partial_labels)
image_paths, labels = dataset_utils.get_training_or_validation_split(
image_paths, labels, validation_split, subset)
print('Found %d images belonging to %d classes.' %
(len(filenames), len(class_names)))
pool.close()
pool.join()
image_paths = [os.path.join(directory, fname) for fname in filenames]
if shuffle:
# Shuffle globally to erase macro-structure
# (the dataset will be further shuffled within a local buffer
# at each iteration)
if seed is None:
seed = np.random.randint(1e6)
rng = np.random.RandomState(seed)
rng.shuffle(image_paths)
rng = np.random.RandomState(seed)
rng.shuffle(labels)
if validation_split:
if not 0 < validation_split < 1:
raise ValueError(
'`validation_split` must be between 0 and 1, received: %s' %
(validation_split,))
num_val_samples = int(validation_split * len(image_paths))
if subset == 'training':
image_paths = image_paths[:-num_val_samples]
labels = labels[:-num_val_samples]
elif subset == 'validation':
image_paths = image_paths[-num_val_samples:]
labels = labels[-num_val_samples:]
else:
raise ValueError('`subset` must be either "training" '
'or "validation", received: %s' % (subset,))
dataset = paths_and_labels_to_dataset(
image_paths=image_paths,
image_size=image_size,
@@ -263,6 +199,8 @@ def image_dataset_from_directory(directory,
# Shuffle locally at each iteration
dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
dataset = dataset.batch(batch_size)
# Users may need to reference `class_names`.
dataset.class_names = class_names
return dataset
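For orientation, a usage sketch of the refactored entry point, assuming a hypothetical `main_directory/` laid out with `class_a/` and `class_b/` subfolders as in the docstring:
```python
import tensorflow as tf

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'main_directory', labels='inferred', label_mode='int',
    image_size=(256, 256), batch_size=32)
print(dataset.class_names)  # e.g. ['class_a', 'class_b']
for images, labels in dataset.take(1):
  print(images.shape)  # (32, 256, 256, 3) with the default color_mode='rgb'
```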
@@ -279,51 +217,11 @@ def paths_and_labels_to_dataset(image_paths,
img_ds = path_ds.map(
lambda x: path_to_image(x, image_size, num_channels, interpolation))
if label_mode:
label_ds = dataset_ops.Dataset.from_tensor_slices(labels)
if label_mode == 'binary':
label_ds = label_ds.map(
lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1))
elif label_mode == 'categorical':
label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
img_ds = dataset_ops.Dataset.zip((img_ds, label_ds))
return img_ds
def iter_valid_files(directory, follow_links):
walk = os.walk(directory, followlinks=follow_links)
for root, _, files in sorted(walk, key=lambda x: x[0]):
for fname in sorted(files):
if fname.lower().endswith(WHITELIST_FORMATS):
yield root, fname
def list_labeled_images_in_directory(directory, class_indices, follow_links):
"""Recursively walks directory and list image paths and their class index.
Arguments:
directory: string, target directory.
class_indices: dict mapping class names to their index.
follow_links: boolean, whether to recursively follow subdirectories
(if False, we only list top-level images in `directory`).
Returns:
tuple `(labels, filenames)`. `labels` is a list of integer
labels and `filenames` is a list of relative image paths corresponding
to these labels.
"""
dirname = os.path.basename(directory)
valid_files = iter_valid_files(directory, follow_links)
labels = []
filenames = []
for root, fname in valid_files:
labels.append(class_indices[dirname])
absolute_path = os.path.join(root, fname)
relative_path = os.path.join(
dirname, os.path.relpath(absolute_path, directory))
filenames.append(relative_path)
return labels, filenames
def path_to_image(path, image_size, num_channels, interpolation):
img = io_ops.read_file(path)
img = image_ops.decode_image(

tensorflow/python/keras/preprocessing/image_dataset_test.py

@@ -35,7 +35,7 @@ except ImportError:
PIL = None
class DatasetFromDirectoryTest(keras_parameterized.TestCase):
class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
def _get_images(self, count=16, color_mode='rgb'):
width = height = 24
@@ -262,7 +262,7 @@ class DatasetFromDirectoryTest(keras_parameterized.TestCase):
with self.assertRaisesRegex(
ValueError,
'Expected the lengths of `labels` to match the number of images'):
'Expected the lengths of `labels` to match the number of files'):
_ = image_dataset.image_dataset_from_directory(
directory, labels=[0, 0, 1, 1])

tensorflow/python/keras/preprocessing/text.py

@@ -21,6 +21,7 @@ from __future__ import print_function
from keras_preprocessing import text
from tensorflow.python.keras.preprocessing.text_dataset import text_dataset_from_directory # pylint: disable=unused-import
from tensorflow.python.util.tf_export import keras_export
hashing_trick = text.hashing_trick

tensorflow/python/keras/preprocessing/text_dataset.py

@@ -0,0 +1,186 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras text dataset generation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
def text_dataset_from_directory(directory,
labels='inferred',
label_mode='int',
class_names=None,
batch_size=32,
max_length=None,
shuffle=True,
seed=None,
validation_split=None,
subset=None,
follow_links=False):
"""Generates a `tf.data.Dataset` from text files in a directory.
If your directory structure is:
```
main_directory/
...class_a/
......a_text_1.txt
......a_text_2.txt
...class_b/
......b_text_1.txt
......b_text_2.txt
```
Then calling `text_dataset_from_directory(main_directory, labels='inferred')`
will return a `tf.data.Dataset` that yields batches of texts from
the subdirectories `class_a` and `class_b`, together with labels
0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).
Only `.txt` files are supported at this time.
Arguments:
directory: Directory where the data is located.
If `labels` is "inferred", it should contain
subdirectories, each containing text files for a class.
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
or a list/tuple of integer labels of the same size as the number of
text files found in the directory. Labels should be sorted according
to the alphanumeric order of the text file paths
(obtained via `os.walk(directory)` in Python).
label_mode:
- 'int': means that the labels are encoded as integers
(e.g. for `sparse_categorical_crossentropy` loss).
- 'categorical' means that the labels are
encoded as a categorical vector
(e.g. for `categorical_crossentropy` loss).
- 'binary' means that the labels (there can be only 2)
are encoded as `float32` scalars with values 0 or 1
(e.g. for `binary_crossentropy`).
- None (no labels).
class_names: Only valid if "labels" is "inferred". This is the explicit
list of class names (must match names of subdirectories). Used
to control the order of the classes
(otherwise alphanumerical order is used).
batch_size: Size of the batches of data. Default: 32.
max_length: Maximum size of a text string. Texts longer than this will
be truncated to `max_length`.
shuffle: Whether to shuffle the data. Default: True.
If set to False, sorts the data in alphanumeric order.
seed: Optional random seed for shuffling and transformations.
validation_split: Optional float between 0 and 1,
fraction of data to reserve for validation.
subset: One of "training" or "validation".
Only used if `validation_split` is set.
follow_links: Whether to visit subdirectories pointed to by symlinks.
Defaults to False.
Returns:
A `tf.data.Dataset` object.
- If `label_mode` is None, it yields `string` tensors of shape
`(batch_size,)`, containing the contents of a batch of text files.
- Otherwise, it yields a tuple `(texts, labels)`, where `texts`
has shape `(batch_size,)` and `labels` follows the format described
below.
Rules regarding labels format:
- if `label_mode` is `int`, the labels are an `int32` tensor of shape
`(batch_size,)`.
- if `label_mode` is `binary`, the labels are a `float32` tensor of
1s and 0s of shape `(batch_size, 1)`.
- if `label_mode` is `categorical`, the labels are a `float32` tensor
of shape `(batch_size, num_classes)`, representing a one-hot
encoding of the class index.
"""
if labels != 'inferred':
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
'the same size as the number of text files in the target '
'directory. If you wish to infer the labels from the subdirectory '
'names in the target directory, pass `labels="inferred"`. '
'If you wish to get a dataset that only contains text samples '
'(no labels), pass `labels=None`.')
if class_names:
raise ValueError('You can only pass `class_names` if the labels are '
'inferred from the subdirectory names in the target '
'directory (`labels="inferred"`).')
if label_mode not in {'int', 'categorical', 'binary', None}:
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if seed is None:
seed = np.random.randint(1e6)
file_paths, labels, class_names = dataset_utils.index_directory(
directory,
labels,
formats=('.txt',),
class_names=class_names,
shuffle=shuffle,
seed=seed,
follow_links=follow_links)
if label_mode == 'binary' and len(class_names) != 2:
raise ValueError(
'When passing `label_mode="binary"`, there must be exactly 2 classes. '
'Found the following classes: %s' % (class_names,))
file_paths, labels = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, subset)
dataset = paths_and_labels_to_dataset(
file_paths=file_paths,
labels=labels,
label_mode=label_mode,
num_classes=len(class_names),
max_length=max_length)
if shuffle:
# Shuffle locally at each iteration
dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
dataset = dataset.batch(batch_size)
# Users may need to reference `class_names`.
dataset.class_names = class_names
return dataset
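A usage sketch of the new function, again assuming a hypothetical `main_directory/` of `.txt` files laid out as in the docstring:
```python
import tensorflow as tf

dataset = tf.keras.preprocessing.text_dataset_from_directory(
    'main_directory', batch_size=32, max_length=500)
print(dataset.class_names)  # e.g. ['class_a', 'class_b']
for texts, labels in dataset.take(1):
  print(texts.dtype, labels.dtype)  # tf.string texts, int32 labels by default
```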
def paths_and_labels_to_dataset(file_paths,
labels,
label_mode,
num_classes,
max_length):
"""Constructs a dataset of text strings and labels."""
path_ds = dataset_ops.Dataset.from_tensor_slices(file_paths)
string_ds = path_ds.map(
lambda x: path_to_string_content(x, max_length))
if label_mode:
label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
string_ds = dataset_ops.Dataset.zip((string_ds, label_ds))
return string_ds
def path_to_string_content(path, max_length):
txt = io_ops.read_file(path)
if max_length is not None:
txt = string_ops.substr(txt, 0, max_length)
return txt
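`string_ops.substr` is the internal counterpart of `tf.strings.substr`, and the `max_length` truncation is byte-based. A quick check with the public op:
```python
import tensorflow as tf

txt = tf.constant('a rather long document body')
print(tf.strings.substr(txt, 0, 10).numpy())  # b'a rather l'
```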

tensorflow/python/keras/preprocessing/text_dataset_test.py

@@ -0,0 +1,218 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for text_dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import shutil
import string
from tensorflow.python.compat import v2_compat
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.preprocessing import text_dataset
from tensorflow.python.platform import test
class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
def _prepare_directory(self,
num_classes=2,
nested_dirs=False,
count=16,
length=20):
# Get a unique temp directory
temp_dir = os.path.join(self.get_temp_dir(), str(random.randint(0, 1e6)))
os.mkdir(temp_dir)
self.addCleanup(shutil.rmtree, temp_dir)
# Generate paths to class subdirectories
paths = []
for class_index in range(num_classes):
class_directory = 'class_%s' % (class_index,)
if nested_dirs:
class_paths = [
class_directory, os.path.join(class_directory, 'subfolder_1'),
os.path.join(class_directory, 'subfolder_2'), os.path.join(
class_directory, 'subfolder_1', 'sub-subfolder')
]
else:
class_paths = [class_directory]
for path in class_paths:
os.mkdir(os.path.join(temp_dir, path))
paths += class_paths
for i in range(count):
path = paths[i % len(paths)]  # distribute files across class folders
filename = os.path.join(path, 'text_%s.txt' % (i,))
f = open(os.path.join(temp_dir, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(length)])
f.write(text)
f.close()
return temp_dir
def test_text_dataset_from_directory_binary(self):
directory = self._prepare_directory(num_classes=2)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode='int', max_length=10)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
self.assertEqual(batch[0].dtype.name, 'string')
self.assertEqual(len(batch[0].numpy()[0]), 10) # Test max_length
self.assertEqual(batch[1].shape, (8,))
self.assertEqual(batch[1].dtype.name, 'int32')
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode='binary')
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
self.assertEqual(batch[0].dtype.name, 'string')
self.assertEqual(batch[1].shape, (8, 1))
self.assertEqual(batch[1].dtype.name, 'float32')
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode='categorical')
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
self.assertEqual(batch[0].dtype.name, 'string')
self.assertEqual(batch[1].shape, (8, 2))
self.assertEqual(batch[1].dtype.name, 'float32')
def test_sample_count(self):
directory = self._prepare_directory(num_classes=4, count=15)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode=None)
sample_count = 0
for batch in dataset:
sample_count += batch.shape[0]
self.assertEqual(sample_count, 15)
def test_text_dataset_from_directory_multiclass(self):
directory = self._prepare_directory(num_classes=4, count=15)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode=None)
batch = next(iter(dataset))
self.assertEqual(batch.shape, (8,))
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode=None)
sample_count = 0
iterator = iter(dataset)
for batch in dataset:
sample_count += next(iterator).shape[0]
self.assertEqual(sample_count, 15)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode='int')
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
self.assertEqual(batch[0].dtype.name, 'string')
self.assertEqual(batch[1].shape, (8,))
self.assertEqual(batch[1].dtype.name, 'int32')
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode='categorical')
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
self.assertEqual(batch[0].dtype.name, 'string')
self.assertEqual(batch[1].shape, (8, 4))
self.assertEqual(batch[1].dtype.name, 'float32')
def test_text_dataset_from_directory_validation_split(self):
directory = self._prepare_directory(num_classes=2, count=10)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=10, validation_split=0.2, subset='training')
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (8,))
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=10, validation_split=0.2, subset='validation')
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (2,))
def test_text_dataset_from_directory_manual_labels(self):
directory = self._prepare_directory(num_classes=2, count=2)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, labels=[0, 1], shuffle=False)
batch = next(iter(dataset))
self.assertLen(batch, 2)
self.assertAllClose(batch[1], [0, 1])
def test_text_dataset_from_directory_follow_links(self):
directory = self._prepare_directory(num_classes=2, count=25,
nested_dirs=True)
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=8, label_mode=None, follow_links=True)
sample_count = 0
for batch in dataset:
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)
def test_text_dataset_from_directory_errors(self):
directory = self._prepare_directory(num_classes=3, count=5)
with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = text_dataset.text_dataset_from_directory(
directory, labels=None)
with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = text_dataset.text_dataset_from_directory(
directory, label_mode='other')
with self.assertRaisesRegex(
ValueError, 'only pass `class_names` if the labels are inferred'):
_ = text_dataset.text_dataset_from_directory(
directory, labels=[0, 0, 1, 1, 1],
class_names=['class_0', 'class_1', 'class_2'])
with self.assertRaisesRegex(
ValueError,
'Expected the lengths of `labels` to match the number of files'):
_ = text_dataset.text_dataset_from_directory(
directory, labels=[0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError, '`class_names` passed did not match'):
_ = text_dataset.text_dataset_from_directory(
directory, class_names=['class_0', 'class_2'])
with self.assertRaisesRegex(ValueError, 'there must be exactly 2 classes'):
_ = text_dataset.text_dataset_from_directory(
directory, label_mode='binary')
with self.assertRaisesRegex(ValueError,
'`validation_split` must be between 0 and 1'):
_ = text_dataset.text_dataset_from_directory(
directory, validation_split=2)
with self.assertRaisesRegex(ValueError,
'`subset` must be either "training" or'):
_ = text_dataset.text_dataset_from_directory(
directory, validation_split=0.2, subset='other')
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()

tensorflow/python/keras/preprocessing/timeseries.py

@@ -106,8 +106,8 @@ def timeseries_dataset_from_array(
```python
input_data = data[:-10]
targets = data[10:]
dataset = tf.keras.preprocessing.timeseries.dataset_from_array(
input_data, targets, sequence_length=10)
dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
input_data, targets, sequence_length=10)
for batch in dataset:
inputs, targets = batch
assert np.array_equal(inputs[0], data[:10]) # First sequence: steps [0-9]
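A self-contained version of the corrected docstring example (assuming `data` is a 1D NumPy range and the `tf.keras.preprocessing.timeseries_dataset_from_array` export touched by this hunk):
```python
import numpy as np
import tensorflow as tf

data = np.arange(100)
input_data = data[:-10]
targets = data[10:]
dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
    input_data, targets, sequence_length=10)
for inputs, target_batch in dataset.take(1):
  assert np.array_equal(inputs[0], data[:10])  # first sequence: steps [0-9]
  assert target_batch[0] == data[10]           # its target: step 10
```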