Create text_dataset_from_directory utility and refactor shared code between text_dataset_from_directory and image_dataset_from_directory.

PiperOrigin-RevId: 307844332
Change-Id: I4b5f094dfa98a71d5ffa59c3134794bce765a575
tensorflow/python/keras/preprocessing/BUILD

@@ -31,6 +31,7 @@ py_library(
py_library(
    name = "image",
    srcs = [
        "dataset_utils.py",
        "image.py",
        "image_dataset.py",
    ],
@@ -69,7 +70,9 @@ py_library(
py_library(
    name = "text",
    srcs = [
        "dataset_utils.py",
        "text.py",
        "text_dataset.py",
    ],
    deps = ["//tensorflow/python:util"],
)
@@ -124,6 +127,19 @@ tf_py_test(
    ],
)

tf_py_test(
    name = "text_dataset_test",
    size = "small",
    srcs = ["text_dataset_test.py"],
    python_version = "PY3",
    deps = [
        ":text",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/keras",
    ],
)

tf_py_test(
    name = "timeseries_test",
    size = "small",
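For reference, the new test target can be exercised on its own; a minimal sketch of the invocation (exact flags depend on the local Bazel configuration):

```sh
bazel test //tensorflow/python/keras/preprocessing:text_dataset_test
```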
tensorflow/python/keras/preprocessing/dataset_utils.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""
|
||||
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def index_directory(directory,
                    labels,
                    formats,
                    class_names=None,
                    shuffle=True,
                    seed=None,
                    follow_links=False):
  """Make list of all files in the subdirs of `directory`, with their labels.

  Args:
    directory: The target directory (string).
    labels: Either "inferred"
        (labels are generated from the directory structure),
        or a list/tuple of integer labels of the same size as the number of
        valid files found in the directory. Labels should be sorted according
        to the alphanumeric order of the file paths
        (obtained via `os.walk(directory)` in Python).
    formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
    class_names: Only valid if "labels" is "inferred". This is the explicit
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling.
    follow_links: Whether to visit subdirectories pointed to by symlinks.

  Returns:
    tuple (file_paths, labels, class_names).
      file_paths: list of file paths (strings).
      labels: list of matching integer labels (same length as file_paths).
      class_names: names of the classes corresponding to these labels, in
        order.
  """
  inferred_class_names = []
  for subdir in sorted(os.listdir(directory)):
    if os.path.isdir(os.path.join(directory, subdir)):
      inferred_class_names.append(subdir)
  if not class_names:
    class_names = inferred_class_names
  else:
    if set(class_names) != set(inferred_class_names):
      raise ValueError(
          'The `class_names` passed did not match the '
          'names of the subdirectories of the target directory. '
          'Expected: %s, but received: %s' %
          (inferred_class_names, class_names))
  class_indices = dict(zip(class_names, range(len(class_names))))

  # Build an index of the files
  # in the different class subfolders.
  pool = multiprocessing.pool.ThreadPool()
  results = []
  filenames = []
  for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
    results.append(
        pool.apply_async(index_subdirectory,
                         (dirpath, class_indices, follow_links, formats)))
  labels_list = []
  for res in results:
    partial_filenames, partial_labels = res.get()
    labels_list.append(partial_labels)
    filenames += partial_filenames
  if labels != 'inferred':
    if len(labels) != len(filenames):
      raise ValueError('Expected the lengths of `labels` to match the number '
                       'of files in the target directory. len(labels) is %s '
                       'while we found %s files in %s.' % (
                           len(labels), len(filenames), directory))
  else:
    i = 0
    labels = np.zeros((len(filenames),), dtype='int32')
    for partial_labels in labels_list:
      labels[i:i + len(partial_labels)] = partial_labels
      i += len(partial_labels)

  print('Found %d files belonging to %d classes.' %
        (len(filenames), len(class_names)))
  pool.close()
  pool.join()
  file_paths = [os.path.join(directory, fname) for fname in filenames]

  if shuffle:
    # Shuffle globally to erase macro-structure
    if seed is None:
      seed = np.random.randint(1e6)
    rng = np.random.RandomState(seed)
    rng.shuffle(file_paths)
    rng = np.random.RandomState(seed)
    rng.shuffle(labels)
  return file_paths, labels, class_names

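As a quick illustration of the indexing contract above, a hedged sketch of what `index_directory` returns for a two-class layout (the `pets/` directory and file names are hypothetical):

```python
# Hypothetical layout: pets/cats/1.jpg, pets/cats/2.jpg, pets/dogs/1.jpg.
# With labels='inferred' and shuffle=False, class names come from the
# sorted subdirectory names and labels follow alphanumeric file order:
file_paths, labels, class_names = index_directory(
    'pets/', labels='inferred', formats=('.jpg',), shuffle=False)
# class_names == ['cats', 'dogs']
# file_paths  == ['pets/cats/1.jpg', 'pets/cats/2.jpg', 'pets/dogs/1.jpg']
# labels      == [0, 0, 1]  (an int32 numpy array)
```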
def iter_valid_files(directory, follow_links, formats):
  walk = os.walk(directory, followlinks=follow_links)
  for root, _, files in sorted(walk, key=lambda x: x[0]):
    for fname in sorted(files):
      if fname.lower().endswith(formats):
        yield root, fname


def index_subdirectory(directory, class_indices, follow_links, formats):
  """Recursively walks directory and lists valid files and their class index.

  Arguments:
    directory: string, target directory.
    class_indices: dict mapping class names to their index.
    follow_links: boolean, whether to recursively follow subdirectories
      (if False, we only list top-level files in `directory`).
    formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").

  Returns:
    tuple `(filenames, labels)`. `filenames` is a list of relative file
    paths, and `labels` is a list of integer labels corresponding to these
    files.
  """
  dirname = os.path.basename(directory)
  valid_files = iter_valid_files(directory, follow_links, formats)
  labels = []
  filenames = []
  for root, fname in valid_files:
    labels.append(class_indices[dirname])
    absolute_path = os.path.join(root, fname)
    relative_path = os.path.join(
        dirname, os.path.relpath(absolute_path, directory))
    filenames.append(relative_path)
  return filenames, labels


def get_training_or_validation_split(samples, labels, validation_split, subset):
  """Potentially restrict samples & labels to a training or validation split.

  Args:
    samples: List of elements.
    labels: List of corresponding labels.
    validation_split: Float, fraction of data to reserve for validation.
    subset: Subset of the data to return.
      Either "training", "validation", or None. If None, we return all of the
      data.

  Returns:
    tuple (samples, labels), potentially restricted to the specified subset.
  """
  if validation_split:
    if not 0 < validation_split < 1:
      raise ValueError(
          '`validation_split` must be between 0 and 1, received: %s' %
          (validation_split,))
  if subset is None:
    return samples, labels

  num_val_samples = int(validation_split * len(samples))
  if subset == 'training':
    samples = samples[:-num_val_samples]
    labels = labels[:-num_val_samples]
  elif subset == 'validation':
    samples = samples[-num_val_samples:]
    labels = labels[-num_val_samples:]
  else:
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))
  return samples, labels


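Note that the validation subset is always taken from the end of the (already shuffled) list, so the two subsets never overlap. A small worked sketch of the arithmetic, assuming 10 samples and `validation_split=0.2`:

```python
# num_val_samples = int(0.2 * 10) = 2 for a hypothetical 10-element list:
samples, labels = list(range(10)), list(range(10))
train_s, _ = get_training_or_validation_split(samples, labels, 0.2, 'training')
val_s, _ = get_training_or_validation_split(samples, labels, 0.2, 'validation')
# train_s == [0, 1, 2, 3, 4, 5, 6, 7]  (all but the last 2)
# val_s   == [8, 9]                    (the last 2)
```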
def labels_to_dataset(labels, label_mode, num_classes):
  label_ds = dataset_ops.Dataset.from_tensor_slices(labels)
  if label_mode == 'binary':
    label_ds = label_ds.map(
        lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1))
  elif label_mode == 'categorical':
    label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
  return label_ds
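Each `label_mode` produces a different element shape from `labels_to_dataset`; a short sketch (two classes assumed):

```python
# label_mode='int':         scalar int32 elements, e.g. 0
# label_mode='binary':      float32 elements of shape (1,), e.g. [0.]
# label_mode='categorical': one-hot float32 elements of shape (num_classes,),
#                           e.g. [1., 0.]
ds = labels_to_dataset([0, 1], label_mode='categorical', num_classes=2)
```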
tensorflow/python/keras/preprocessing/image_dataset.py

@@ -18,17 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.ops import array_ops
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@@ -49,7 +45,7 @@ def image_dataset_from_directory(directory,
                                 subset=None,
                                 interpolation='bilinear',
                                 follow_links=False):
  """Generates a Dataset from image files in a directory.
  """Generates a `tf.data.Dataset` from image files in a directory.

  If your directory structure is:

@@ -63,10 +59,10 @@ def image_dataset_from_directory(directory,
  ......b_image_2.jpg
  ```

  Then calling `from_directory(main_directory, labels='inferred')`
  will return a Dataset that yields batches of images from
  Then calling `image_dataset_from_directory(main_directory, labels='inferred')`
  will return a `tf.data.Dataset` that yields batches of images from
  the subdirectories `class_a` and `class_b`, together with labels
  0 and 1 (0 corresponding to class_a and 1 corresponding to class_b).
  0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

  Supported image formats: jpeg, png, bmp, gif.
  Animated gifs are truncated to the first frame.
@@ -126,22 +122,22 @@ def image_dataset_from_directory(directory,
      has shape `(batch_size, image_size[0], image_size[1], num_channels)`,
      and `labels` follows the format described below.

  Rules regarding labels format:
    - if `label_mode` is `int`, the labels are an `int32` tensor of shape
      `(batch_size,)`.
    - if `label_mode` is `binary`, the labels are a `float32` tensor of
      1s and 0s of shape `(batch_size, 1)`.
    - if `label_mode` is `categorical`, the labels are a `float32` tensor
      of shape `(batch_size, num_classes)`, representing a one-hot
      encoding of the class index.

  Rules regarding number of channels in the yielded images:
    - if `color_mode` is `grayscale`,
      there's 1 channel in the image tensors.
    - if `color_mode` is `rgb`,
      there are 3 channels in the image tensors.
    - if `color_mode` is `rgba`,
      there are 4 channels in the image tensors.
  """
  if labels != 'inferred':
    if not isinstance(labels, (list, tuple)):
@@ -172,85 +168,25 @@ def image_dataset_from_directory(directory,
        'Received: %s' % (color_mode,))
  interpolation = image_preprocessing.get_interpolation(interpolation)

  inferred_class_names = []
  for subdir in sorted(os.listdir(directory)):
    if os.path.isdir(os.path.join(directory, subdir)):
      inferred_class_names.append(subdir)
  if not class_names:
    class_names = inferred_class_names
  else:
    if set(class_names) != set(inferred_class_names):
      raise ValueError(
          'The `class_names` passed did not match the '
          'names of the subdirectories of the target directory. '
          'Expected: %s, but received: %s' %
          (inferred_class_names, class_names))
  class_indices = dict(zip(class_names, range(len(class_names))))
  if seed is None:
    seed = np.random.randint(1e6)
  image_paths, labels, class_names = dataset_utils.index_directory(
      directory,
      labels,
      formats=WHITELIST_FORMATS,
      class_names=class_names,
      shuffle=shuffle,
      seed=seed,
      follow_links=follow_links)

  if label_mode == 'binary' and len(class_names) != 2:
    raise ValueError(
        'When passing `label_mode="binary", there must exactly 2 classes. '
        'Found the following classes: %s' % (class_names,))

  # Build an index of the images
  # in the different class subfolders.
  pool = multiprocessing.pool.ThreadPool()
  results = []
  filenames = []
  for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
    results.append(
        pool.apply_async(list_labeled_images_in_directory,
                         (dirpath, class_indices, follow_links)))
  labels_list = []
  for res in results:
    partial_labels, partial_filenames = res.get()
    labels_list.append(partial_labels)
    filenames += partial_filenames
  if labels != 'inferred':
    if len(labels) != len(filenames):
      raise ValueError('Expected the lengths of `labels` to match the number '
                       'of images in the target directory. len(labels) is %s '
                       'while we found %s images in %s.' % (
                           len(labels), len(filenames), directory))
  else:
    i = 0
    labels = np.zeros((len(filenames),), dtype='int32')
    for partial_labels in labels_list:
      labels[i:i + len(partial_labels)] = partial_labels
      i += len(partial_labels)
  image_paths, labels = dataset_utils.get_training_or_validation_split(
      image_paths, labels, validation_split, subset)

  print('Found %d images belonging to %d classes.' %
        (len(filenames), len(class_names)))
  pool.close()
  pool.join()
  image_paths = [os.path.join(directory, fname) for fname in filenames]

  if shuffle:
    # Shuffle globally to erase macro-structure
    # (the dataset will be further shuffled within a local buffer
    # at each iteration)
    if seed is None:
      seed = np.random.randint(1e6)
    rng = np.random.RandomState(seed)
    rng.shuffle(image_paths)
    rng = np.random.RandomState(seed)
    rng.shuffle(labels)

  if validation_split:
    if not 0 < validation_split < 1:
      raise ValueError(
          '`validation_split` must be between 0 and 1, received: %s' %
          (validation_split,))
    num_val_samples = int(validation_split * len(image_paths))
    if subset == 'training':
      image_paths = image_paths[:-num_val_samples]
      labels = labels[:-num_val_samples]
    elif subset == 'validation':
      image_paths = image_paths[-num_val_samples:]
      labels = labels[-num_val_samples:]
    else:
      raise ValueError('`subset` must be either "training" '
                       'or "validation", received: %s' % (subset,))
  dataset = paths_and_labels_to_dataset(
      image_paths=image_paths,
      image_size=image_size,
@@ -263,6 +199,8 @@ def image_dataset_from_directory(directory,
  # Shuffle locally at each iteration
  dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
  # Users may need to reference `class_names`.
  dataset.class_names = class_names
  return dataset


@@ -279,51 +217,11 @@ def paths_and_labels_to_dataset(image_paths,
  img_ds = path_ds.map(
      lambda x: path_to_image(x, image_size, num_channels, interpolation))
  if label_mode:
    label_ds = dataset_ops.Dataset.from_tensor_slices(labels)
    if label_mode == 'binary':
      label_ds = label_ds.map(
          lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1))
    elif label_mode == 'categorical':
      label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
    label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
    img_ds = dataset_ops.Dataset.zip((img_ds, label_ds))
  return img_ds


def iter_valid_files(directory, follow_links):
  walk = os.walk(directory, followlinks=follow_links)
  for root, _, files in sorted(walk, key=lambda x: x[0]):
    for fname in sorted(files):
      if fname.lower().endswith(WHITELIST_FORMATS):
        yield root, fname


def list_labeled_images_in_directory(directory, class_indices, follow_links):
  """Recursively walks directory and list image paths and their class index.

  Arguments:
    directory: string, target directory.
    class_indices: dict mapping class names to their index.
    follow_links: boolean, whether to recursively follow subdirectories
      (if False, we only list top-level images in `directory`).

  Returns:
    tuple `(labels, filenames)`. `labels` is a list of integer
    labels and `filenames` is a list of relative image paths corresponding
    to these labels.
  """
  dirname = os.path.basename(directory)
  valid_files = iter_valid_files(directory, follow_links)
  labels = []
  filenames = []
  for root, fname in valid_files:
    labels.append(class_indices[dirname])
    absolute_path = os.path.join(root, fname)
    relative_path = os.path.join(
        dirname, os.path.relpath(absolute_path, directory))
    filenames.append(relative_path)
  return labels, filenames


def path_to_image(path, image_size, num_channels, interpolation):
  img = io_ops.read_file(path)
  img = image_ops.decode_image(
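The public behavior of `image_dataset_from_directory` is unchanged by the refactor; a hedged usage sketch (the directory name is hypothetical, and the defaults elided by this hunk are assumed to be `image_size=(256, 256)` and `batch_size=32`):

```python
# Assumes main_directory/class_a/*.jpg and main_directory/class_b/*.jpg.
# Yields (images, labels) batches: float32 images of shape (32, 256, 256, 3)
# and int32 labels of shape (32,), given the assumed defaults above.
train_ds = image_dataset_from_directory(
    'main_directory', labels='inferred', label_mode='int',
    validation_split=0.2, subset='training', seed=1337)
```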
tensorflow/python/keras/preprocessing/image_dataset_test.py

@@ -35,7 +35,7 @@ except ImportError:
  PIL = None


class DatasetFromDirectoryTest(keras_parameterized.TestCase):
class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):

  def _get_images(self, count=16, color_mode='rgb'):
    width = height = 24
@@ -262,7 +262,7 @@ class DatasetFromDirectoryTest(keras_parameterized.TestCase):

    with self.assertRaisesRegex(
        ValueError,
        'Expected the lengths of `labels` to match the number of images'):
        'Expected the lengths of `labels` to match the number of files'):
      _ = image_dataset.image_dataset_from_directory(
          directory, labels=[0, 0, 1, 1])

tensorflow/python/keras/preprocessing/text.py

@@ -21,6 +21,7 @@ from __future__ import print_function

from keras_preprocessing import text

from tensorflow.python.keras.preprocessing.text_dataset import text_dataset_from_directory  # pylint: disable=unused-import
from tensorflow.python.util.tf_export import keras_export

hashing_trick = text.hashing_trick
tensorflow/python/keras/preprocessing/text_dataset.py (new file, 186 lines)
@@ -0,0 +1,186 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras text dataset generation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops


def text_dataset_from_directory(directory,
                                labels='inferred',
                                label_mode='int',
                                class_names=None,
                                batch_size=32,
                                max_length=None,
                                shuffle=True,
                                seed=None,
                                validation_split=None,
                                subset=None,
                                follow_links=False):
  """Generates a `tf.data.Dataset` from text files in a directory.

  If your directory structure is:

  ```
  main_directory/
  ...class_a/
  ......a_text_1.txt
  ......a_text_2.txt
  ...class_b/
  ......b_text_1.txt
  ......b_text_2.txt
  ```

  Then calling `text_dataset_from_directory(main_directory, labels='inferred')`
  will return a `tf.data.Dataset` that yields batches of texts from
  the subdirectories `class_a` and `class_b`, together with labels
  0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

  Only `.txt` files are supported at this time.

  Arguments:
    directory: Directory where the data is located.
        If `labels` is "inferred", it should contain
        subdirectories, each containing text files for a class.
        Otherwise, the directory structure is ignored.
    labels: Either "inferred"
        (labels are generated from the directory structure),
        or a list/tuple of integer labels of the same size as the number of
        text files found in the directory. Labels should be sorted according
        to the alphanumeric order of the text file paths
        (obtained via `os.walk(directory)` in Python).
    label_mode:
        - 'int': means that the labels are encoded as integers
            (e.g. for `sparse_categorical_crossentropy` loss).
        - 'categorical' means that the labels are
            encoded as a categorical vector
            (e.g. for `categorical_crossentropy` loss).
        - 'binary' means that the labels (there can be only 2)
            are encoded as `float32` scalars with values 0 or 1
            (e.g. for `binary_crossentropy`).
        - None (no labels).
    class_names: Only valid if "labels" is "inferred". This is the explicit
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    batch_size: Size of the batches of data. Default: 32.
    max_length: Maximum size of a text string. Texts longer than this will
        be truncated to `max_length`.
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling and transformations.
    validation_split: Optional float between 0 and 1,
        fraction of data to reserve for validation.
    subset: One of "training" or "validation".
        Only used if `validation_split` is set.
    follow_links: Whether to visit subdirectories pointed to by symlinks.
        Defaults to False.

  Returns:
    A `tf.data.Dataset` object.
      - If `label_mode` is None, it yields `string` tensors of shape
        `(batch_size,)`, containing the contents of a batch of text files.
      - Otherwise, it yields a tuple `(texts, labels)`, where `texts`
        has shape `(batch_size,)` and `labels` follows the format described
        below.

  Rules regarding labels format:
    - if `label_mode` is `int`, the labels are an `int32` tensor of shape
      `(batch_size,)`.
    - if `label_mode` is `binary`, the labels are a `float32` tensor of
      1s and 0s of shape `(batch_size, 1)`.
    - if `label_mode` is `categorical`, the labels are a `float32` tensor
      of shape `(batch_size, num_classes)`, representing a one-hot
      encoding of the class index.
  """
  if labels != 'inferred':
    if not isinstance(labels, (list, tuple)):
      raise ValueError(
          '`labels` argument should be a list/tuple of integer labels, of '
          'the same size as the number of text files in the target '
          'directory. If you wish to infer the labels from the subdirectory '
          'names in the target directory, pass `labels="inferred"`. '
          'If you wish to get a dataset that only contains text samples '
          '(no labels), pass `labels=None`.')
    if class_names:
      raise ValueError('You can only pass `class_names` if the labels are '
                       'inferred from the subdirectory names in the target '
                       'directory (`labels="inferred"`).')
  if label_mode not in {'int', 'categorical', 'binary', None}:
    raise ValueError(
        '`label_mode` argument must be one of "int", "categorical", "binary", '
        'or None. Received: %s' % (label_mode,))

  if seed is None:
    seed = np.random.randint(1e6)
  file_paths, labels, class_names = dataset_utils.index_directory(
      directory,
      labels,
      formats=('.txt',),
      class_names=class_names,
      shuffle=shuffle,
      seed=seed,
      follow_links=follow_links)

  if label_mode == 'binary' and len(class_names) != 2:
    raise ValueError(
        'When passing `label_mode="binary", there must exactly 2 classes. '
        'Found the following classes: %s' % (class_names,))

  file_paths, labels = dataset_utils.get_training_or_validation_split(
      file_paths, labels, validation_split, subset)

  dataset = paths_and_labels_to_dataset(
      file_paths=file_paths,
      labels=labels,
      label_mode=label_mode,
      num_classes=len(class_names),
      max_length=max_length)
  if shuffle:
    # Shuffle locally at each iteration
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
  # Users may need to reference `class_names`.
  dataset.class_names = class_names
  return dataset


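A usage sketch of the new entry point (the directory layout is hypothetical; the full contract is in the docstring above):

```python
# Assumes main_directory/class_a/*.txt and main_directory/class_b/*.txt.
train_ds = text_dataset_from_directory(
    'main_directory', labels='inferred', label_mode='binary',
    max_length=500, validation_split=0.2, subset='training', seed=42)
for texts, labels in train_ds.take(1):
  # With at least one full batch: texts is a string tensor of shape (32,),
  # labels a float32 tensor of shape (32, 1).
  print(texts.shape, labels.shape)
```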
def paths_and_labels_to_dataset(file_paths,
                                labels,
                                label_mode,
                                num_classes,
                                max_length):
  """Constructs a dataset of text strings and labels."""
  path_ds = dataset_ops.Dataset.from_tensor_slices(file_paths)
  string_ds = path_ds.map(
      lambda x: path_to_string_content(x, max_length))
  if label_mode:
    label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
    string_ds = dataset_ops.Dataset.zip((string_ds, label_ds))
  return string_ds


def path_to_string_content(path, max_length):
  txt = io_ops.read_file(path)
  if max_length is not None:
    txt = string_ops.substr(txt, 0, max_length)
  return txt
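`max_length` truncates by bytes via `string_ops.substr` (`tf.strings.substr`); a small sketch with a hypothetical file:

```python
# If 'sample.txt' contains b'hello world, hello keras', reading it with
# max_length=11 yields the scalar string b'hello world' (its first 11 bytes).
ds = paths_and_labels_to_dataset(
    file_paths=['sample.txt'], labels=[], label_mode=None,
    num_classes=None, max_length=11)
```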
tensorflow/python/keras/preprocessing/text_dataset_test.py (new file, 218 lines)
@@ -0,0 +1,218 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for text_dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import shutil
import string

from tensorflow.python.compat import v2_compat
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.preprocessing import text_dataset
from tensorflow.python.platform import test


class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):

  def _prepare_directory(self,
                         num_classes=2,
                         nested_dirs=False,
                         count=16,
                         length=20):
    # Get a unique temp directory
    temp_dir = os.path.join(self.get_temp_dir(), str(random.randint(0, 1e6)))
    os.mkdir(temp_dir)
    self.addCleanup(shutil.rmtree, temp_dir)

    # Generate paths to class subdirectories
    paths = []
    for class_index in range(num_classes):
      class_directory = 'class_%s' % (class_index,)
      if nested_dirs:
        class_paths = [
            class_directory, os.path.join(class_directory, 'subfolder_1'),
            os.path.join(class_directory, 'subfolder_2'), os.path.join(
                class_directory, 'subfolder_1', 'sub-subfolder')
        ]
      else:
        class_paths = [class_directory]
      for path in class_paths:
        os.mkdir(os.path.join(temp_dir, path))
      paths += class_paths

    for i in range(count):
      # Round-robin files across the subdirectories (indexing by `i`, not
      # `count`, so that every class actually receives files).
      path = paths[i % len(paths)]
      filename = os.path.join(path, 'text_%s.txt' % (i,))
      f = open(os.path.join(temp_dir, filename), 'w')
      text = ''.join([random.choice(string.printable) for _ in range(length)])
      f.write(text)
      f.close()
    return temp_dir

  def test_text_dataset_from_directory_binary(self):
    directory = self._prepare_directory(num_classes=2)
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='int', max_length=10)
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(len(batch[0].numpy()[0]), 10)  # Test max_length
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='binary')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8, 1))
    self.assertEqual(batch[1].dtype.name, 'float32')

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8, 2))
    self.assertEqual(batch[1].dtype.name, 'float32')

  def test_sample_count(self):
    directory = self._prepare_directory(num_classes=4, count=15)
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    sample_count = 0
    for batch in dataset:
      sample_count += batch.shape[0]
    self.assertEqual(sample_count, 15)

  def test_text_dataset_from_directory_multiclass(self):
    directory = self._prepare_directory(num_classes=4, count=15)

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    batch = next(iter(dataset))
    self.assertEqual(batch.shape, (8,))

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None)
    sample_count = 0
    iterator = iter(dataset)
    for batch in dataset:
      sample_count += next(iterator).shape[0]
    self.assertEqual(sample_count, 15)

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='int')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8,))
    self.assertEqual(batch[1].dtype.name, 'int32')

    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode='categorical')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    self.assertEqual(batch[0].dtype.name, 'string')
    self.assertEqual(batch[1].shape, (8, 4))
    self.assertEqual(batch[1].dtype.name, 'float32')

  def test_text_dataset_from_directory_validation_split(self):
    directory = self._prepare_directory(num_classes=2, count=10)
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=10, validation_split=0.2, subset='training')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (8,))
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=10, validation_split=0.2, subset='validation')
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (2,))

  def test_text_dataset_from_directory_manual_labels(self):
    directory = self._prepare_directory(num_classes=2, count=2)
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, labels=[0, 1], shuffle=False)
    batch = next(iter(dataset))
    self.assertLen(batch, 2)
    self.assertAllClose(batch[1], [0, 1])

  def test_text_dataset_from_directory_follow_links(self):
    directory = self._prepare_directory(num_classes=2, count=25,
                                        nested_dirs=True)
    dataset = text_dataset.text_dataset_from_directory(
        directory, batch_size=8, label_mode=None, follow_links=True)
    sample_count = 0
    for batch in dataset:
      sample_count += batch.shape[0]
    self.assertEqual(sample_count, 25)

  def test_text_dataset_from_directory_errors(self):
    directory = self._prepare_directory(num_classes=3, count=5)

    with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
      _ = text_dataset.text_dataset_from_directory(
          directory, labels=None)

    with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
      _ = text_dataset.text_dataset_from_directory(
          directory, label_mode='other')

    with self.assertRaisesRegex(
        ValueError, 'only pass `class_names` if the labels are inferred'):
      _ = text_dataset.text_dataset_from_directory(
          directory, labels=[0, 0, 1, 1, 1],
          class_names=['class_0', 'class_1', 'class_2'])

    with self.assertRaisesRegex(
        ValueError,
        'Expected the lengths of `labels` to match the number of files'):
      _ = text_dataset.text_dataset_from_directory(
          directory, labels=[0, 0, 1, 1])

    with self.assertRaisesRegex(
        ValueError, '`class_names` passed did not match'):
      _ = text_dataset.text_dataset_from_directory(
          directory, class_names=['class_0', 'class_2'])

    with self.assertRaisesRegex(ValueError, 'there must exactly 2 classes'):
      _ = text_dataset.text_dataset_from_directory(
          directory, label_mode='binary')

    with self.assertRaisesRegex(ValueError,
                                '`validation_split` must be between 0 and 1'):
      _ = text_dataset.text_dataset_from_directory(
          directory, validation_split=2)

    with self.assertRaisesRegex(ValueError,
                                '`subset` must be either "training" or'):
      _ = text_dataset.text_dataset_from_directory(
          directory, validation_split=0.2, subset='other')


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()
tensorflow/python/keras/preprocessing/timeseries.py

@@ -106,8 +106,8 @@ def timeseries_dataset_from_array(
  ```python
  input_data = data[:-10]
  targets = data[10:]
  dataset = tf.keras.preprocessing.timeseries.dataset_from_array(
      input_data, targets, sequence_length=10)
  dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
      input_data, targets, sequence_length=10)
  for batch in dataset:
    inputs, targets = batch
    assert np.array_equal(inputs[0], data[:10])  # First sequence: steps [0-9]