The final leftover piece is keras.layer.

PiperOrigin-RevId: 276390172
Change-Id: Ie2efc73e9b987df15fb8085f82b8369eca8ce664
# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for text data preprocessing utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras.preprocessing import text as preprocessing_text
from tensorflow.python.platform import test


class TestText(test.TestCase):

  def test_one_hot(self):
    text = 'The cat sat on the mat.'
    encoded = preprocessing_text.one_hot(text, 5)
    self.assertEqual(len(encoded), 6)
    self.assertLessEqual(np.max(encoded), 4)
    self.assertGreaterEqual(np.min(encoded), 0)

    # Test on unicode.
    text = u'The cat sat on the mat.'
    encoded = preprocessing_text.one_hot(text, 5)
    self.assertEqual(len(encoded), 6)
    self.assertLessEqual(np.max(encoded), 4)
    self.assertGreaterEqual(np.min(encoded), 0)

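  # Despite its name, `one_hot` appears to be a hashing-based encoder: each
  # word is hashed to an index, so repeated words share an index within a
  # single process (the exact values are hash-dependent). A minimal sketch
  # of that property:
  def test_one_hot_repeated_word(self):
    encoded = preprocessing_text.one_hot('the cat sat on the mat', 5)
    # 'the' occurs at positions 0 and 4; both occurrences get the same index.
    self.assertEqual(encoded[0], encoded[4])
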
  def test_tokenizer(self):
    texts = [
        'The cat sat on the mat.',
        'The dog sat on the log.',
        'Dogs and cats living together.'
    ]
    tokenizer = preprocessing_text.Tokenizer(num_words=10)
    tokenizer.fit_on_texts(texts)

    sequences = []
    for seq in tokenizer.texts_to_sequences_generator(texts):
      sequences.append(seq)
    self.assertLess(np.max(np.max(sequences)), 10)
    self.assertEqual(np.min(np.min(sequences)), 1)

    tokenizer.fit_on_sequences(sequences)

    for mode in ['binary', 'count', 'tfidf', 'freq']:
      matrix = tokenizer.texts_to_matrix(texts, mode)
      self.assertEqual(matrix.shape, (3, 10))

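  # `Tokenizer` assigns indices by descending word frequency, starting at 1
  # (index 0 is reserved, e.g. for padding), which is why the minimum index
  # asserted above is 1. A small sketch of that convention:
  def test_tokenizer_index_starts_at_one(self):
    tokenizer = preprocessing_text.Tokenizer()
    tokenizer.fit_on_texts(['the the the cat'])
    # 'the' is the most frequent word, so it gets the lowest index.
    self.assertEqual(tokenizer.word_index['the'], 1)
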
  def test_hashing_trick_hash(self):
    text = 'The cat sat on the mat.'
    encoded = preprocessing_text.hashing_trick(text, 5)
    self.assertEqual(len(encoded), 6)
    self.assertLessEqual(np.max(encoded), 4)
    self.assertGreaterEqual(np.min(encoded), 1)

  def test_hashing_trick_md5(self):
    text = 'The cat sat on the mat.'
    encoded = preprocessing_text.hashing_trick(
        text, 5, hash_function='md5')
    self.assertEqual(len(encoded), 6)
    self.assertLessEqual(np.max(encoded), 4)
    self.assertGreaterEqual(np.min(encoded), 1)

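  # The default hash for `hashing_trick` is Python's built-in `hash`, which
  # is salted per process for strings on Python 3, so encodings are only
  # stable within a single run; 'md5' yields indices that are reproducible
  # across runs. A minimal sketch of the md5 determinism:
  def test_hashing_trick_md5_deterministic(self):
    text = 'The cat sat on the mat.'
    first = preprocessing_text.hashing_trick(text, 5, hash_function='md5')
    second = preprocessing_text.hashing_trick(text, 5, hash_function='md5')
    self.assertEqual(first, second)
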
  def test_tokenizer_oov_flag(self):
    x_train = ['This text has only known words']
    x_test = ['This text has some unknown words']  # 2 OOVs: some, unknown

    # Default, without OOV flag.
    tokenizer = preprocessing_text.Tokenizer()
    tokenizer.fit_on_texts(x_train)
    x_test_seq = tokenizer.texts_to_sequences(x_test)
    self.assertEqual(len(x_test_seq[0]), 4)  # Discards the 2 OOVs.

    # With the OOV feature.
    tokenizer = preprocessing_text.Tokenizer(oov_token='<unk>')
    tokenizer.fit_on_texts(x_train)
    x_test_seq = tokenizer.texts_to_sequences(x_test)
    self.assertEqual(len(x_test_seq[0]), 6)  # OOVs marked in place.

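  # When `oov_token` is set, the token itself appears to be added to
  # `word_index` (typically at index 1), and unseen words then map to that
  # index instead of being dropped. A sketch of that mapping:
  def test_tokenizer_oov_token_maps_unseen_words(self):
    tokenizer = preprocessing_text.Tokenizer(oov_token='<unk>')
    tokenizer.fit_on_texts(['only known words'])
    oov_index = tokenizer.word_index['<unk>']
    # An unseen word maps to the OOV index rather than disappearing.
    self.assertEqual(tokenizer.texts_to_sequences(['unseen'])[0], [oov_index])
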
  def test_sequential_fit(self):
    texts = [
        'The cat sat on the mat.', 'The dog sat on the log.',
        'Dogs and cats living together.'
    ]
    word_sequences = [['The', 'cat', 'is', 'sitting'],
                      ['The', 'dog', 'is', 'standing']]
    tokenizer = preprocessing_text.Tokenizer()
    tokenizer.fit_on_texts(texts)
    tokenizer.fit_on_texts(word_sequences)

    self.assertEqual(tokenizer.document_count, 5)

    tokenizer.texts_to_matrix(texts)
    tokenizer.texts_to_matrix(word_sequences)

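  # `fit_on_texts` is cumulative across calls and accepts lists of words as
  # well as raw strings, which is why five documents (three strings plus two
  # word lists) are counted above. A sketch of the cumulative counting:
  def test_fit_on_texts_is_cumulative(self):
    tokenizer = preprocessing_text.Tokenizer()
    tokenizer.fit_on_texts(['one document'])
    tokenizer.fit_on_texts(['another document'])
    self.assertEqual(tokenizer.document_count, 2)
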
  def test_text_to_word_sequence(self):
    text = 'hello! ? world!'
    seq = preprocessing_text.text_to_word_sequence(text)
    self.assertEqual(seq, ['hello', 'world'])

  def test_text_to_word_sequence_multichar_split(self):
    text = 'hello!stop?world!'
    seq = preprocessing_text.text_to_word_sequence(text, split='stop')
    self.assertEqual(seq, ['hello', 'world'])

  def test_text_to_word_sequence_unicode(self):
    text = u'ali! veli? kırk dokuz elli'
    seq = preprocessing_text.text_to_word_sequence(text)
    self.assertEqual(seq, [u'ali', u'veli', u'kırk', u'dokuz', u'elli'])

  def test_text_to_word_sequence_unicode_multichar_split(self):
    text = u'ali!stopveli?stopkırkstopdokuzstopelli'
    seq = preprocessing_text.text_to_word_sequence(text, split='stop')
    self.assertEqual(seq, [u'ali', u'veli', u'kırk', u'dokuz', u'elli'])

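  # By default `text_to_word_sequence` lowercases its input and strips a
  # standard set of punctuation via the `filters` argument before splitting
  # on whitespace; the tests above rely on '!' and '?' being filtered out.
  # A small sketch of the lowercasing behaviour:
  def test_text_to_word_sequence_lowercases(self):
    seq = preprocessing_text.text_to_word_sequence('Hello, World!')
    self.assertEqual(seq, ['hello', 'world'])
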
  def test_tokenizer_unicode(self):
    texts = [
        u'ali veli kırk dokuz elli',
        u'ali veli kırk dokuz elli veli kırk dokuz'
    ]
    tokenizer = preprocessing_text.Tokenizer(num_words=5)
    tokenizer.fit_on_texts(texts)

    self.assertEqual(len(tokenizer.word_counts), 5)

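  # `word_counts` tracks every distinct word seen regardless of `num_words`
  # (the corpus above happens to contain exactly five unique words);
  # `num_words` only caps sequence and matrix output. A sketch:
  def test_word_counts_ignores_num_words(self):
    tokenizer = preprocessing_text.Tokenizer(num_words=2)
    tokenizer.fit_on_texts(['a b c d'])
    self.assertEqual(len(tokenizer.word_counts), 4)
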

if __name__ == '__main__':
  test.main()