From 09e67c47ee7ce8139093f9a455b7c7a94877193d Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 22 Apr 2020 10:25:47 -0700 Subject: [PATCH] Create `text_dataset_from_directory` utility and refactor shared code between `text_dataset_from_directory` and `image_dataset_from_directory`. PiperOrigin-RevId: 307844332 Change-Id: I4b5f094dfa98a71d5ffa59c3134794bce765a575 --- tensorflow/python/keras/preprocessing/BUILD | 16 ++ .../keras/preprocessing/dataset_utils.py | 201 ++++++++++++++++ .../keras/preprocessing/image_dataset.py | 172 +++----------- .../keras/preprocessing/image_dataset_test.py | 4 +- tensorflow/python/keras/preprocessing/text.py | 1 + .../keras/preprocessing/text_dataset.py | 186 +++++++++++++++ .../keras/preprocessing/text_dataset_test.py | 218 ++++++++++++++++++ .../python/keras/preprocessing/timeseries.py | 4 +- 8 files changed, 661 insertions(+), 141 deletions(-) create mode 100644 tensorflow/python/keras/preprocessing/dataset_utils.py create mode 100644 tensorflow/python/keras/preprocessing/text_dataset.py create mode 100644 tensorflow/python/keras/preprocessing/text_dataset_test.py diff --git a/tensorflow/python/keras/preprocessing/BUILD b/tensorflow/python/keras/preprocessing/BUILD index 3cfdb1e2c78..403bc6e4808 100644 --- a/tensorflow/python/keras/preprocessing/BUILD +++ b/tensorflow/python/keras/preprocessing/BUILD @@ -31,6 +31,7 @@ py_library( py_library( name = "image", srcs = [ + "dataset_utils.py", "image.py", "image_dataset.py", ], @@ -69,7 +70,9 @@ py_library( py_library( name = "text", srcs = [ + "dataset_utils.py", "text.py", + "text_dataset.py", ], deps = ["//tensorflow/python:util"], ) @@ -124,6 +127,19 @@ tf_py_test( ], ) +tf_py_test( + name = "text_dataset_test", + size = "small", + srcs = ["text_dataset_test.py"], + python_version = "PY3", + deps = [ + ":text", + "//tensorflow/python:client_testlib", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/keras", + ], +) + tf_py_test( name = "timeseries_test", size = "small", diff --git a/tensorflow/python/keras/preprocessing/dataset_utils.py b/tensorflow/python/keras/preprocessing/dataset_utils.py new file mode 100644 index 00000000000..70d9566889f --- /dev/null +++ b/tensorflow/python/keras/preprocessing/dataset_utils.py @@ -0,0 +1,201 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Keras dataset loading utilities."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import multiprocessing
+import os
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def index_directory(directory,
+                    labels,
+                    formats,
+                    class_names=None,
+                    shuffle=True,
+                    seed=None,
+                    follow_links=False):
+  """Make list of all files in the subdirs of `directory`, with their labels.
+
+  Args:
+    directory: The target directory (string).
+    labels: Either "inferred"
+        (labels are generated from the directory structure),
+        or a list/tuple of integer labels of the same size as the number of
+        valid files found in the directory. Labels should be sorted according
+        to the alphanumeric order of the file paths
+        (obtained via `os.walk(directory)` in Python).
+    formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
+    class_names: Only valid if "labels" is "inferred". This is the explicit
+        list of class names (must match names of subdirectories). Used
+        to control the order of the classes
+        (otherwise alphanumerical order is used).
+    shuffle: Whether to shuffle the data. Default: True.
+        If set to False, sorts the data in alphanumeric order.
+    seed: Optional random seed for shuffling.
+    follow_links: Whether to visit subdirectories pointed to by symlinks.
+
+  Returns:
+    tuple (file_paths, labels, class_names).
+      file_paths: list of file paths (strings).
+      labels: list of matching integer labels (same length as file_paths)
+      class_names: names of the classes corresponding to these labels, in order.
+  """
+  inferred_class_names = []
+  for subdir in sorted(os.listdir(directory)):
+    if os.path.isdir(os.path.join(directory, subdir)):
+      inferred_class_names.append(subdir)
+  if not class_names:
+    class_names = inferred_class_names
+  else:
+    if set(class_names) != set(inferred_class_names):
+      raise ValueError(
+          'The `class_names` passed did not match the '
+          'names of the subdirectories of the target directory. '
+          'Expected: %s, but received: %s' %
+          (inferred_class_names, class_names))
+  class_indices = dict(zip(class_names, range(len(class_names))))
+
+  # Build an index of the files
+  # in the different class subfolders.
+  pool = multiprocessing.pool.ThreadPool()
+  results = []
+  filenames = []
+  for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
+    results.append(
+        pool.apply_async(index_subdirectory,
+                         (dirpath, class_indices, follow_links, formats)))
+  labels_list = []
+  for res in results:
+    partial_filenames, partial_labels = res.get()
+    labels_list.append(partial_labels)
+    filenames += partial_filenames
+  if labels != 'inferred':
+    if len(labels) != len(filenames):
+      raise ValueError('Expected the lengths of `labels` to match the number '
+                       'of files in the target directory. len(labels) is %s '
+                       'while we found %s files in %s.' % (
+                           len(labels), len(filenames), directory))
+  else:
+    i = 0
+    labels = np.zeros((len(filenames),), dtype='int32')
+    for partial_labels in labels_list:
+      labels[i:i + len(partial_labels)] = partial_labels
+      i += len(partial_labels)
+
+  print('Found %d files belonging to %d classes.' %
+        (len(filenames), len(class_names)))
+  pool.close()
+  pool.join()
+  file_paths = [os.path.join(directory, fname) for fname in filenames]
+
+  if shuffle:
+    # Shuffle globally to erase macro-structure
+    if seed is None:
+      seed = np.random.randint(1e6)
+    rng = np.random.RandomState(seed)
+    rng.shuffle(file_paths)
+    rng = np.random.RandomState(seed)
+    rng.shuffle(labels)
+  return file_paths, labels, class_names
+
+
+def iter_valid_files(directory, follow_links, formats):
+  walk = os.walk(directory, followlinks=follow_links)
+  for root, _, files in sorted(walk, key=lambda x: x[0]):
+    for fname in sorted(files):
+      if fname.lower().endswith(formats):
+        yield root, fname
+
+
+def index_subdirectory(directory, class_indices, follow_links, formats):
+  """Recursively walks directory and lists file paths and their class index.
+
+  Arguments:
+    directory: string, target directory.
+    class_indices: dict mapping class names to their index.
+    follow_links: boolean, whether to recursively follow subdirectories
+      (if False, we only list top-level files in `directory`).
+    formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
+
+  Returns:
+    tuple `(filenames, labels)`. `filenames` is a list of relative file
+      paths, and `labels` is a list of integer labels corresponding to these
+      files.
+  """
+  dirname = os.path.basename(directory)
+  valid_files = iter_valid_files(directory, follow_links, formats)
+  labels = []
+  filenames = []
+  for root, fname in valid_files:
+    labels.append(class_indices[dirname])
+    absolute_path = os.path.join(root, fname)
+    relative_path = os.path.join(
+        dirname, os.path.relpath(absolute_path, directory))
+    filenames.append(relative_path)
+  return filenames, labels
+
+
+def get_training_or_validation_split(samples, labels, validation_split, subset):
+  """Potentially restrict samples & labels to a training or validation split.
+
+  Args:
+    samples: List of elements.
+    labels: List of corresponding labels.
+    validation_split: Float, fraction of data to reserve for validation.
+    subset: Subset of the data to return.
+      Either "training", "validation", or None. If None, we return all of the
+      data.
+
+  Returns:
+    tuple (samples, labels), potentially restricted to the specified subset.
+ """ + if validation_split: + if not 0 < validation_split < 1: + raise ValueError( + '`validation_split` must be between 0 and 1, received: %s' % + (validation_split,)) + if subset is None: + return samples, labels + + num_val_samples = int(validation_split * len(samples)) + if subset == 'training': + samples = samples[:-num_val_samples] + labels = labels[:-num_val_samples] + elif subset == 'validation': + samples = samples[-num_val_samples:] + labels = labels[-num_val_samples:] + else: + raise ValueError('`subset` must be either "training" ' + 'or "validation", received: %s' % (subset,)) + return samples, labels + + +def labels_to_dataset(labels, label_mode, num_classes): + label_ds = dataset_ops.Dataset.from_tensor_slices(labels) + if label_mode == 'binary': + label_ds = label_ds.map( + lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1)) + elif label_mode == 'categorical': + label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes)) + return label_ds diff --git a/tensorflow/python/keras/preprocessing/image_dataset.py b/tensorflow/python/keras/preprocessing/image_dataset.py index 500a41fc8c5..2e24ef887ae 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset.py +++ b/tensorflow/python/keras/preprocessing/image_dataset.py @@ -18,17 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import multiprocessing -import os - import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.keras.layers.preprocessing import image_preprocessing -from tensorflow.python.ops import array_ops +from tensorflow.python.keras.preprocessing import dataset_utils from tensorflow.python.ops import image_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import keras_export @@ -49,7 +45,7 @@ def image_dataset_from_directory(directory, subset=None, interpolation='bilinear', follow_links=False): - """Generates a Dataset from image files in a directory. + """Generates a `tf.data.Dataset` from image files in a directory. If your directory structure is: @@ -63,10 +59,10 @@ def image_dataset_from_directory(directory, ......b_image_2.jpg ``` - Then calling `from_directory(main_directory, labels='inferred')` - will return a Dataset that yields batches of images from + Then calling `image_dataset_from_directory(main_directory, labels='inferred')` + will return a `tf.data.Dataset` that yields batches of images from the subdirectories `class_a` and `class_b`, together with labels - 0 and 1 (0 corresponding to class_a and 1 corresponding to class_b). + 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). Supported image formats: jpeg, png, bmp, gif. Animated gifs are truncated to the first frame. @@ -126,22 +122,22 @@ def image_dataset_from_directory(directory, has shape `(batch_size, image_size[0], image_size[1], num_channels)`, and `labels` follows the format described below. - Rules regarding labels format: - - if `label_mode` is `int`, the labels are an `int32` tensor of shape - `(batch_size,)`. - - if `label_mode` is `binary`, the labels are a `float32` tensor of - 1s and 0s of shape `(batch_size, 1)`. - - if `label_mode` is `categorial`, the labels are a `float32` tensor - of shape `(batch_size, num_classes)`, representing a one-hot - encoding of the class index. + Rules regarding labels format: + - if `label_mode` is `int`, the labels are an `int32` tensor of shape + `(batch_size,)`. 
+  - if `label_mode` is `binary`, the labels are a `float32` tensor of
+    1s and 0s of shape `(batch_size, 1)`.
+  - if `label_mode` is `categorical`, the labels are a `float32` tensor
+    of shape `(batch_size, num_classes)`, representing a one-hot
+    encoding of the class index.
 
-  Rules regarding number of channels in the yielded images:
-    - if `color_mode` is `grayscale`,
-      there's 1 channel in the image tensors.
-    - if `color_mode` is `rgb`,
-      there are 3 channel in the image tensors.
-    - if `color_mode` is `rgba`,
-      there are 4 channel in the image tensors.
+  Rules regarding number of channels in the yielded images:
+  - if `color_mode` is `grayscale`,
+    there's 1 channel in the image tensors.
+  - if `color_mode` is `rgb`,
+    there are 3 channels in the image tensors.
+  - if `color_mode` is `rgba`,
+    there are 4 channels in the image tensors.
   """
   if labels != 'inferred':
     if not isinstance(labels, (list, tuple)):
@@ -172,85 +168,25 @@ def image_dataset_from_directory(directory,
                      'Received: %s' % (color_mode,))
   interpolation = image_preprocessing.get_interpolation(interpolation)
 
-  inferred_class_names = []
-  for subdir in sorted(os.listdir(directory)):
-    if os.path.isdir(os.path.join(directory, subdir)):
-      inferred_class_names.append(subdir)
-  if not class_names:
-    class_names = inferred_class_names
-  else:
-    if set(class_names) != set(inferred_class_names):
-      raise ValueError(
-          'The `class_names` passed did not match the '
-          'names of the subdirectories of the target directory. '
-          'Expected: %s, but received: %s' %
-          (inferred_class_names, class_names))
-  class_indices = dict(zip(class_names, range(len(class_names))))
+  if seed is None:
+    seed = np.random.randint(1e6)
+  image_paths, labels, class_names = dataset_utils.index_directory(
+      directory,
+      labels,
+      formats=WHITELIST_FORMATS,
+      class_names=class_names,
+      shuffle=shuffle,
+      seed=seed,
+      follow_links=follow_links)
 
   if label_mode == 'binary' and len(class_names) != 2:
     raise ValueError(
         'When passing `label_mode="binary", there must exactly 2 classes. '
         'Found the following classes: %s' % (class_names,))
 
-  # Build an index of the images
-  # in the different class subfolders.
-  pool = multiprocessing.pool.ThreadPool()
-  results = []
-  filenames = []
-  for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
-    results.append(
-        pool.apply_async(list_labeled_images_in_directory,
-                         (dirpath, class_indices, follow_links)))
-  labels_list = []
-  for res in results:
-    partial_labels, partial_filenames = res.get()
-    labels_list.append(partial_labels)
-    filenames += partial_filenames
-  if labels != 'inferred':
-    if len(labels) != len(filenames):
-      raise ValueError('Expected the lengths of `labels` to match the number '
-                       'of images in the target directory. len(labels) is %s '
-                       'while we found %s images in %s.' % (
-                           len(labels), len(filenames), directory))
-  else:
-    i = 0
-    labels = np.zeros((len(filenames),), dtype='int32')
-    for partial_labels in labels_list:
-      labels[i:i + len(partial_labels)] = partial_labels
-      i += len(partial_labels)
+  image_paths, labels = dataset_utils.get_training_or_validation_split(
+      image_paths, labels, validation_split, subset)
 
-  print('Found %d images belonging to %d classes.'
% - (len(filenames), len(class_names))) - pool.close() - pool.join() - image_paths = [os.path.join(directory, fname) for fname in filenames] - - if shuffle: - # Shuffle globally to erase macro-structure - # (the dataset will be further shuffled within a local buffer - # at each iteration) - if seed is None: - seed = np.random.randint(1e6) - rng = np.random.RandomState(seed) - rng.shuffle(image_paths) - rng = np.random.RandomState(seed) - rng.shuffle(labels) - - if validation_split: - if not 0 < validation_split < 1: - raise ValueError( - '`validation_split` must be between 0 and 1, received: %s' % - (validation_split,)) - num_val_samples = int(validation_split * len(image_paths)) - if subset == 'training': - image_paths = image_paths[:-num_val_samples] - labels = labels[:-num_val_samples] - elif subset == 'validation': - image_paths = image_paths[-num_val_samples:] - labels = labels[-num_val_samples:] - else: - raise ValueError('`subset` must be either "training" ' - 'or "validation", received: %s' % (subset,)) dataset = paths_and_labels_to_dataset( image_paths=image_paths, image_size=image_size, @@ -263,6 +199,8 @@ def image_dataset_from_directory(directory, # Shuffle locally at each iteration dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed) dataset = dataset.batch(batch_size) + # Users may need to reference `class_names`. + dataset.class_names = class_names return dataset @@ -279,51 +217,11 @@ def paths_and_labels_to_dataset(image_paths, img_ds = path_ds.map( lambda x: path_to_image(x, image_size, num_channels, interpolation)) if label_mode: - label_ds = dataset_ops.Dataset.from_tensor_slices(labels) - if label_mode == 'binary': - label_ds = label_ds.map( - lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1)) - elif label_mode == 'categorical': - label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes)) + label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes) img_ds = dataset_ops.Dataset.zip((img_ds, label_ds)) return img_ds -def iter_valid_files(directory, follow_links): - walk = os.walk(directory, followlinks=follow_links) - for root, _, files in sorted(walk, key=lambda x: x[0]): - for fname in sorted(files): - if fname.lower().endswith(WHITELIST_FORMATS): - yield root, fname - - -def list_labeled_images_in_directory(directory, class_indices, follow_links): - """Recursively walks directory and list image paths and their class index. - - Arguments: - directory: string, target directory. - class_indices: dict mapping class names to their index. - follow_links: boolean, whether to recursively follow subdirectories - (if False, we only list top-level images in `directory`). - - Returns: - tuple `(labels, filenames)`. `labels` is a list of integer - labels and `filenames` is a list of relative image paths corresponding - to these labels. 
- """ - dirname = os.path.basename(directory) - valid_files = iter_valid_files(directory, follow_links) - labels = [] - filenames = [] - for root, fname in valid_files: - labels.append(class_indices[dirname]) - absolute_path = os.path.join(root, fname) - relative_path = os.path.join( - dirname, os.path.relpath(absolute_path, directory)) - filenames.append(relative_path) - return labels, filenames - - def path_to_image(path, image_size, num_channels, interpolation): img = io_ops.read_file(path) img = image_ops.decode_image( diff --git a/tensorflow/python/keras/preprocessing/image_dataset_test.py b/tensorflow/python/keras/preprocessing/image_dataset_test.py index 629f03d4494..aa10c1c7bac 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset_test.py +++ b/tensorflow/python/keras/preprocessing/image_dataset_test.py @@ -35,7 +35,7 @@ except ImportError: PIL = None -class DatasetFromDirectoryTest(keras_parameterized.TestCase): +class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase): def _get_images(self, count=16, color_mode='rgb'): width = height = 24 @@ -262,7 +262,7 @@ class DatasetFromDirectoryTest(keras_parameterized.TestCase): with self.assertRaisesRegex( ValueError, - 'Expected the lengths of `labels` to match the number of images'): + 'Expected the lengths of `labels` to match the number of files'): _ = image_dataset.image_dataset_from_directory( directory, labels=[0, 0, 1, 1]) diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py index e501789a1a0..5a343e0069c 100644 --- a/tensorflow/python/keras/preprocessing/text.py +++ b/tensorflow/python/keras/preprocessing/text.py @@ -21,6 +21,7 @@ from __future__ import print_function from keras_preprocessing import text +from tensorflow.python.keras.preprocessing.text_dataset import text_dataset_from_directory # pylint: disable=unused-import from tensorflow.python.util.tf_export import keras_export hashing_trick = text.hashing_trick diff --git a/tensorflow/python/keras/preprocessing/text_dataset.py b/tensorflow/python/keras/preprocessing/text_dataset.py new file mode 100644 index 00000000000..8748f7258d8 --- /dev/null +++ b/tensorflow/python/keras/preprocessing/text_dataset.py @@ -0,0 +1,186 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Keras text dataset generation utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.keras.preprocessing import dataset_utils
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import string_ops
+
+
+def text_dataset_from_directory(directory,
+                                labels='inferred',
+                                label_mode='int',
+                                class_names=None,
+                                batch_size=32,
+                                max_length=None,
+                                shuffle=True,
+                                seed=None,
+                                validation_split=None,
+                                subset=None,
+                                follow_links=False):
+  """Generates a `tf.data.Dataset` from text files in a directory.
+
+  If your directory structure is:
+
+  ```
+  main_directory/
+  ...class_a/
+  ......a_text_1.txt
+  ......a_text_2.txt
+  ...class_b/
+  ......b_text_1.txt
+  ......b_text_2.txt
+  ```
+
+  Then calling `text_dataset_from_directory(main_directory, labels='inferred')`
+  will return a `tf.data.Dataset` that yields batches of texts from
+  the subdirectories `class_a` and `class_b`, together with labels
+  0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).
+
+  Only `.txt` files are supported at this time.
+
+  Arguments:
+    directory: Directory where the data is located.
+        If `labels` is "inferred", it should contain
+        subdirectories, each containing text files for a class.
+        Otherwise, the directory structure is ignored.
+    labels: Either "inferred"
+        (labels are generated from the directory structure),
+        or a list/tuple of integer labels of the same size as the number of
+        text files found in the directory. Labels should be sorted according
+        to the alphanumeric order of the text file paths
+        (obtained via `os.walk(directory)` in Python).
+    label_mode:
+        - 'int': means that the labels are encoded as integers
+            (e.g. for `sparse_categorical_crossentropy` loss).
+        - 'categorical' means that the labels are
+            encoded as a categorical vector
+            (e.g. for `categorical_crossentropy` loss).
+        - 'binary' means that the labels (there can be only 2)
+            are encoded as `float32` scalars with values 0 or 1
+            (e.g. for `binary_crossentropy`).
+        - None (no labels).
+    class_names: Only valid if "labels" is "inferred". This is the explicit
+        list of class names (must match names of subdirectories). Used
+        to control the order of the classes
+        (otherwise alphanumerical order is used).
+    batch_size: Size of the batches of data. Default: 32.
+    max_length: Maximum size of a text string. Texts longer than this will
+        be truncated to `max_length`.
+    shuffle: Whether to shuffle the data. Default: True.
+        If set to False, sorts the data in alphanumeric order.
+    seed: Optional random seed for shuffling and transformations.
+    validation_split: Optional float between 0 and 1,
+        fraction of data to reserve for validation.
+    subset: One of "training" or "validation".
+        Only used if `validation_split` is set.
+    follow_links: Whether to visit subdirectories pointed to by symlinks.
+        Defaults to False.
+
+  Returns:
+    A `tf.data.Dataset` object.
+      - If `label_mode` is None, it yields `string` tensors of shape
+        `(batch_size,)`, containing the contents of a batch of text files.
+      - Otherwise, it yields a tuple `(texts, labels)`, where `texts`
+        has shape `(batch_size,)` and `labels` follows the format described
+        below.
+
+  Rules regarding labels format:
+  - if `label_mode` is `int`, the labels are an `int32` tensor of shape
+    `(batch_size,)`.
+  - if `label_mode` is `binary`, the labels are a `float32` tensor of
+    1s and 0s of shape `(batch_size, 1)`.
+  - if `label_mode` is `categorical`, the labels are a `float32` tensor
+    of shape `(batch_size, num_classes)`, representing a one-hot
+    encoding of the class index.
+  """
+  if labels != 'inferred':
+    if not isinstance(labels, (list, tuple)):
+      raise ValueError(
+          '`labels` argument should be a list/tuple of integer labels, of '
+          'the same size as the number of text files in the target '
+          'directory. If you wish to infer the labels from the subdirectory '
+          'names in the target directory, pass `labels="inferred"`. '
+          'If you wish to get a dataset that only contains text samples '
+          '(no labels), pass `labels=None`.')
+    if class_names:
+      raise ValueError('You can only pass `class_names` if the labels are '
+                       'inferred from the subdirectory names in the target '
+                       'directory (`labels="inferred"`).')
+  if label_mode not in {'int', 'categorical', 'binary', None}:
+    raise ValueError(
+        '`label_mode` argument must be one of "int", "categorical", "binary", '
+        'or None. Received: %s' % (label_mode,))
+
+  if seed is None:
+    seed = np.random.randint(1e6)
+  file_paths, labels, class_names = dataset_utils.index_directory(
+      directory,
+      labels,
+      formats=('.txt',),
+      class_names=class_names,
+      shuffle=shuffle,
+      seed=seed,
+      follow_links=follow_links)
+
+  if label_mode == 'binary' and len(class_names) != 2:
+    raise ValueError(
+        'When passing `label_mode="binary"`, there must be exactly 2 classes. '
+        'Found the following classes: %s' % (class_names,))
+
+  file_paths, labels = dataset_utils.get_training_or_validation_split(
+      file_paths, labels, validation_split, subset)
+
+  dataset = paths_and_labels_to_dataset(
+      file_paths=file_paths,
+      labels=labels,
+      label_mode=label_mode,
+      num_classes=len(class_names),
+      max_length=max_length)
+  if shuffle:
+    # Shuffle locally at each iteration
+    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
+  dataset = dataset.batch(batch_size)
+  # Users may need to reference `class_names`.
+  dataset.class_names = class_names
+  return dataset
+
+
+def paths_and_labels_to_dataset(file_paths,
+                                labels,
+                                label_mode,
+                                num_classes,
+                                max_length):
+  """Constructs a dataset of text strings and labels."""
+  path_ds = dataset_ops.Dataset.from_tensor_slices(file_paths)
+  string_ds = path_ds.map(
+      lambda x: path_to_string_content(x, max_length))
+  if label_mode:
+    label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
+    string_ds = dataset_ops.Dataset.zip((string_ds, label_ds))
+  return string_ds
+
+
+def path_to_string_content(path, max_length):
+  txt = io_ops.read_file(path)
+  if max_length is not None:
+    txt = string_ops.substr(txt, 0, max_length)
+  return txt
diff --git a/tensorflow/python/keras/preprocessing/text_dataset_test.py b/tensorflow/python/keras/preprocessing/text_dataset_test.py
new file mode 100644
index 00000000000..c0e231e69a9
--- /dev/null
+++ b/tensorflow/python/keras/preprocessing/text_dataset_test.py
@@ -0,0 +1,218 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for text_dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+import shutil
+import string
+
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.preprocessing import text_dataset
+from tensorflow.python.platform import test
+
+
+class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
+
+  def _prepare_directory(self,
+                         num_classes=2,
+                         nested_dirs=False,
+                         count=16,
+                         length=20):
+    # Get a unique temp directory
+    temp_dir = os.path.join(self.get_temp_dir(), str(random.randint(0, 1e6)))
+    os.mkdir(temp_dir)
+    self.addCleanup(shutil.rmtree, temp_dir)
+
+    # Generate paths to class subdirectories
+    paths = []
+    for class_index in range(num_classes):
+      class_directory = 'class_%s' % (class_index,)
+      if nested_dirs:
+        class_paths = [
+            class_directory, os.path.join(class_directory, 'subfolder_1'),
+            os.path.join(class_directory, 'subfolder_2'), os.path.join(
+                class_directory, 'subfolder_1', 'sub-subfolder')
+        ]
+      else:
+        class_paths = [class_directory]
+      for path in class_paths:
+        os.mkdir(os.path.join(temp_dir, path))
+      paths += class_paths
+
+    for i in range(count):
+      path = paths[i % len(paths)]
+      filename = os.path.join(path, 'text_%s.txt' % (i,))
+      f = open(os.path.join(temp_dir, filename), 'w')
+      text = ''.join([random.choice(string.printable) for _ in range(length)])
+      f.write(text)
+      f.close()
+    return temp_dir
+
+  def test_text_dataset_from_directory_binary(self):
+    directory = self._prepare_directory(num_classes=2)
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode='int', max_length=10)
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (8,))
+    self.assertEqual(batch[0].dtype.name, 'string')
+    self.assertEqual(len(batch[0].numpy()[0]), 10)  # Test max_length
+    self.assertEqual(batch[1].shape, (8,))
+    self.assertEqual(batch[1].dtype.name, 'int32')
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode='binary')
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (8,))
+    self.assertEqual(batch[0].dtype.name, 'string')
+    self.assertEqual(batch[1].shape, (8, 1))
+    self.assertEqual(batch[1].dtype.name, 'float32')
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode='categorical')
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (8,))
+    self.assertEqual(batch[0].dtype.name, 'string')
+    self.assertEqual(batch[1].shape, (8, 2))
+    self.assertEqual(batch[1].dtype.name, 'float32')
+
+  def test_sample_count(self):
+    directory = self._prepare_directory(num_classes=4, count=15)
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode=None)
+    sample_count = 0
+    for batch in dataset:
+      sample_count += batch.shape[0]
+    self.assertEqual(sample_count, 15)
+
+  def test_text_dataset_from_directory_multiclass(self):
+    directory = self._prepare_directory(num_classes=4, count=15)
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode=None)
+    batch = next(iter(dataset))
+    self.assertEqual(batch.shape, (8,))
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode=None)
+    sample_count = 0
+    iterator = iter(dataset)
+    for batch in dataset:
+      sample_count += next(iterator).shape[0]
+    self.assertEqual(sample_count, 15)
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode='int')
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (8,))
+    self.assertEqual(batch[0].dtype.name, 'string')
+    self.assertEqual(batch[1].shape, (8,))
+    self.assertEqual(batch[1].dtype.name, 'int32')
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode='categorical')
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (8,))
+    self.assertEqual(batch[0].dtype.name, 'string')
+    self.assertEqual(batch[1].shape, (8, 4))
+    self.assertEqual(batch[1].dtype.name, 'float32')
+
+  def test_text_dataset_from_directory_validation_split(self):
+    directory = self._prepare_directory(num_classes=2, count=10)
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=10, validation_split=0.2, subset='training')
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (8,))
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=10, validation_split=0.2, subset='validation')
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertEqual(batch[0].shape, (2,))
+
+  def test_text_dataset_from_directory_manual_labels(self):
+    directory = self._prepare_directory(num_classes=2, count=2)
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, labels=[0, 1], shuffle=False)
+    batch = next(iter(dataset))
+    self.assertLen(batch, 2)
+    self.assertAllClose(batch[1], [0, 1])
+
+  def test_text_dataset_from_directory_follow_links(self):
+    directory = self._prepare_directory(num_classes=2, count=25,
+                                        nested_dirs=True)
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=8, label_mode=None, follow_links=True)
+    sample_count = 0
+    for batch in dataset:
+      sample_count += batch.shape[0]
+    self.assertEqual(sample_count, 25)
+
+  def test_text_dataset_from_directory_errors(self):
+    directory = self._prepare_directory(num_classes=3, count=5)
+
+    with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, labels=None)
+
+    with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, label_mode='other')
+
+    with self.assertRaisesRegex(
+        ValueError, 'only pass `class_names` if the labels are inferred'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, labels=[0, 0, 1, 1, 1],
+          class_names=['class_0', 'class_1', 'class_2'])
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'Expected the lengths of `labels` to match the number of files'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, labels=[0, 0, 1, 1])
+
+    with self.assertRaisesRegex(
+        ValueError, '`class_names` passed did not match'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, class_names=['class_0', 'class_2'])
+
+    with self.assertRaisesRegex(ValueError, 'there must be exactly 2 classes'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, label_mode='binary')
+
+    with self.assertRaisesRegex(ValueError,
+                                '`validation_split` must be between 0 and 1'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, validation_split=2)
+
+    with self.assertRaisesRegex(ValueError,
+                                '`subset` must be either "training" or'):
+      _ = text_dataset.text_dataset_from_directory(
+          directory, validation_split=0.2, subset='other')
+
+
+if __name__ == '__main__':
+  v2_compat.enable_v2_behavior()
+  test.main()
diff --git a/tensorflow/python/keras/preprocessing/timeseries.py b/tensorflow/python/keras/preprocessing/timeseries.py
index 373594b9356..64e2d06554d 100644
--- a/tensorflow/python/keras/preprocessing/timeseries.py
+++ b/tensorflow/python/keras/preprocessing/timeseries.py
@@ -106,8 +106,8 @@ def timeseries_dataset_from_array(
   ```python
   input_data = data[:-10]
   targets = data[10:]
-  dataset = tf.keras.preprocessing.timeseries.dataset_from_array(
-      input_data, targets, sequence_length=10)
+  dataset = tf.keras.preprocessing.timeseries_dataset_from_array(
+      input_data, targets, sequence_length=10)
   for batch in dataset:
     inputs, targets = batch
     assert np.array_equal(inputs[0], data[:10])  # First sequence: steps [0-9]
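
---

A minimal usage sketch of the new entry point, not part of the patch. It assumes the symbol is reachable as `tf.keras.preprocessing.text_dataset_from_directory` (the unused-import in `text.py` above suggests it is re-exported there), and the `reviews/` directory layout is hypothetical:

```python
import tensorflow as tf

# Hypothetical layout: reviews/pos/*.txt and reviews/neg/*.txt
train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'reviews',
    label_mode='binary',   # two classes -> float32 labels of shape (batch, 1)
    batch_size=32,
    max_length=500,        # each file truncated to its first 500 bytes
    validation_split=0.2,
    subset='training',
    seed=1337)             # reuse the same seed for the validation subset
val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'reviews', label_mode='binary', batch_size=32, max_length=500,
    validation_split=0.2, subset='validation', seed=1337)

print(train_ds.class_names)  # ['neg', 'pos'] -- attribute set by this patch
for texts, labels in train_ds.take(1):
  print(texts.shape, labels.shape)  # (32,) (32, 1)
```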
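The split in `dataset_utils.get_training_or_validation_split` slices the already globally shuffled file list and takes the validation subset from the end, which is why the training and validation calls must shuffle with the same `seed` to stay disjoint. A standalone sketch of just the slicing (labels omitted for brevity):

```python
def split(samples, validation_split, subset):
  # Mirrors the slicing in get_training_or_validation_split:
  # validation is the tail of the shuffled list, training is the rest.
  num_val = int(validation_split * len(samples))
  if subset == 'training':
    return samples[:-num_val]
  if subset == 'validation':
    return samples[-num_val:]
  return samples  # subset=None returns everything

files = ['f0', 'f1', 'f2', 'f3', 'f4']
assert split(files, 0.2, 'training') == ['f0', 'f1', 'f2', 'f3']
assert split(files, 0.2, 'validation') == ['f4']
```

One edge worth noting: `int()` floors, so with few files and a small split `num_val` can be 0, in which case `samples[:-0]` yields an empty training set; the patch does not guard against that.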
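What the shared `dataset_utils.labels_to_dataset` helper produces for each `label_mode`, expressed with equivalent public TF ops (the label values below are made up):

```python
import numpy as np
import tensorflow as tf

labels = np.array([0, 1, 1, 0], dtype='int32')
int_ds = tf.data.Dataset.from_tensor_slices(labels)  # label_mode='int'

# label_mode='binary' (requires exactly 2 classes): float32, shape (1,).
binary_ds = int_ds.map(
    lambda x: tf.expand_dims(tf.cast(x, 'float32'), axis=-1))

# label_mode='categorical': one-hot float32, shape (num_classes,).
categorical_ds = int_ds.map(lambda x: tf.one_hot(x, 2))

print(next(iter(binary_ds)).numpy())       # [0.]
print(next(iter(categorical_ds)).numpy())  # [1. 0.]
```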
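The `max_length` truncation in `path_to_string_content` uses `string_ops.substr`, whose public alias is `tf.strings.substr`; it counts bytes by default, so a multi-byte UTF-8 character at the cut point can be split. A quick illustration:

```python
import tensorflow as tf

txt = tf.constant('a fairly long document body')
print(tf.strings.substr(txt, 0, 10).numpy())  # b'a fairly l'
```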