Automated rollback of commit 0d2f3be5eb
PiperOrigin-RevId: 296320816 Change-Id: Ib8b5857178fa10513755de65ffcde1adf6dabad3
This commit is contained in:
parent
7f6685951b
commit
0685f70521
@ -303,10 +303,9 @@ cuda_py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "preprocessing_normalization_test",
|
name = "normalization_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["normalization_test.py"],
|
srcs = ["normalization_test.py"],
|
||||||
main = "normalization_test.py",
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":normalization",
|
":normalization",
|
||||||
@ -317,10 +316,9 @@ tf_py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "preprocessing_text_vectorization_test",
|
name = "text_vectorization_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["text_vectorization_test.py"],
|
srcs = ["text_vectorization_test.py"],
|
||||||
main = "text_vectorization_test.py",
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":preprocessing_test_utils",
|
":preprocessing_test_utils",
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import lookup_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
# The string tokens in the extracted vocabulary
|
# The string tokens in the extracted vocabulary
|
||||||
@ -66,7 +67,13 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
1. If this value is more than 1, OOV inputs are hashed to determine their
|
1. If this value is more than 1, OOV inputs are hashed to determine their
|
||||||
OOV value; if this value is 0, passing an OOV input will result in a
|
OOV value; if this value is 0, passing an OOV input will result in a
|
||||||
runtime error.
|
runtime error.
|
||||||
vocabulary: An optional list of vocabulary terms.
|
vocabulary: An optional list of vocabulary terms, or a path to a text file
|
||||||
|
containing a vocabulary to load into this layer. The file should contain
|
||||||
|
one token per line. In either case, the vocabulary must be unique; if
|
||||||
|
the list or file contains the same token multiple times, an error will
|
||||||
|
be thrown. Note that when passing a vocabulary - either as a list or as
|
||||||
|
a file - the vocabulary will not be present in the layer's config dict;
|
||||||
|
it will instead be a part of the layer's weights.
|
||||||
reserve_zero: Whether to reserve the index 0, which indicates pad values in
|
reserve_zero: Whether to reserve the index 0, which indicates pad values in
|
||||||
the Keras masking system. If True, the output of this layer will be in the
|
the Keras masking system. If True, the output of this layer will be in the
|
||||||
range `[1...max_tokens+1)`; if False, the output will be in the range
|
range `[1...max_tokens+1)`; if False, the output will be in the range
|
||||||
@ -164,10 +171,38 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
self._inverse_table = None
|
self._inverse_table = None
|
||||||
|
|
||||||
if vocabulary is not None:
|
if vocabulary is not None:
|
||||||
self._export_vocab = True
|
if isinstance(vocabulary, str):
|
||||||
|
vocabulary = self._get_vocabulary_from_file(vocabulary)
|
||||||
|
|
||||||
|
vocabulary_set = set(vocabulary)
|
||||||
|
if len(vocabulary) != len(vocabulary_set):
|
||||||
|
repeated_items = [
|
||||||
|
item for item, count in collections.Counter(vocabulary).items()
|
||||||
|
if count > 1
|
||||||
|
]
|
||||||
|
raise ValueError("The passed vocabulary has at least one repeated "
|
||||||
|
"term. Please uniquify your dataset before passing "
|
||||||
|
"it to IndexLookup(). The repeated terms are %s" %
|
||||||
|
repeated_items)
|
||||||
self.set_vocabulary(vocabulary)
|
self.set_vocabulary(vocabulary)
|
||||||
else:
|
|
||||||
self._export_vocab = False
|
def _get_vocabulary_from_file(self, vocabulary_path):
|
||||||
|
vocab = []
|
||||||
|
with gfile.GFile(vocabulary_path, "r") as reader:
|
||||||
|
while True:
|
||||||
|
# Get the next line, and break if it is None.
|
||||||
|
text = reader.readline()
|
||||||
|
if not text:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Convert the raw text into UTF8 and strip whitespace.
|
||||||
|
if isinstance(text, str):
|
||||||
|
token = text
|
||||||
|
elif isinstance(text, bytes):
|
||||||
|
token = text.decode("utf-8", "ignore")
|
||||||
|
token = token.strip()
|
||||||
|
vocab.append(token)
|
||||||
|
return vocab
|
||||||
|
|
||||||
def _get_table_data(self):
|
def _get_table_data(self):
|
||||||
keys, values = self._table.export()
|
keys, values = self._table.export()
|
||||||
@ -256,11 +291,10 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
return [x for _, x in sorted(zip(values, keys))]
|
return [x for _, x in sorted(zip(values, keys))]
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
vocabulary = self.get_vocabulary() if self._export_vocab else None
|
|
||||||
config = {
|
config = {
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"num_oov_tokens": self.num_oov_tokens,
|
"num_oov_tokens": self.num_oov_tokens,
|
||||||
"vocabulary": vocabulary,
|
"vocabulary": None,
|
||||||
"reserve_zero": self.reserve_zero,
|
"reserve_zero": self.reserve_zero,
|
||||||
"mask_zero": self.mask_zero,
|
"mask_zero": self.mask_zero,
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.keras.layers.preprocessing import preprocessing_test_util
|
|||||||
from tensorflow.python.keras.saving import save
|
from tensorflow.python.keras.saving import save
|
||||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -356,7 +357,22 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
|||||||
output_dataset = model.predict(input_array)
|
output_dataset = model.predict(input_array)
|
||||||
self.assertAllEqual(expected_output, output_dataset)
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
def test_int_output_explicit_vocab_from_config(self):
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
|
||||||
|
preprocessing_test_utils.PreprocessingLayerTest
|
||||||
|
):
|
||||||
|
|
||||||
|
def _write_to_temp_file(self, file_name, vocab_list):
|
||||||
|
vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
|
||||||
|
with gfile.GFile(vocab_path, "w") as writer:
|
||||||
|
for vocab in vocab_list:
|
||||||
|
writer.write(vocab + "\n")
|
||||||
|
writer.flush()
|
||||||
|
writer.close()
|
||||||
|
return vocab_path
|
||||||
|
|
||||||
|
def test_int_output_explicit_vocab(self):
|
||||||
vocab_data = ["earth", "wind", "and", "fire"]
|
vocab_data = ["earth", "wind", "and", "fire"]
|
||||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||||
["fire", "and", "earth", "michigan"]])
|
["fire", "and", "earth", "michigan"]])
|
||||||
@ -366,10 +382,22 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
|||||||
layer = get_layer_class()(vocabulary=vocab_data)
|
layer = get_layer_class()(vocabulary=vocab_data)
|
||||||
int_data = layer(input_data)
|
int_data = layer(input_data)
|
||||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||||
|
output_dataset = model.predict(input_array)
|
||||||
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
with CustomObjectScope({"IndexLookup": get_layer_class()}):
|
def test_int_output_explicit_vocab_from_file(self):
|
||||||
new_model = keras.Model.from_config(model.get_config())
|
vocab_list = ["earth", "wind", "and", "fire"]
|
||||||
output_dataset = new_model.predict(input_array)
|
vocab_path = self._write_to_temp_file("vocab_file", vocab_list)
|
||||||
|
|
||||||
|
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||||
|
["fire", "and", "earth", "michigan"]])
|
||||||
|
expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
|
||||||
|
|
||||||
|
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||||
|
layer = get_layer_class()(vocabulary=vocab_path)
|
||||||
|
int_data = layer(input_data)
|
||||||
|
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||||
|
output_dataset = model.predict(input_array)
|
||||||
self.assertAllEqual(expected_output, output_dataset)
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
def test_vocab_appending(self):
|
def test_vocab_appending(self):
|
||||||
@ -387,6 +415,17 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
|||||||
output_dataset = model.predict(input_array)
|
output_dataset = model.predict(input_array)
|
||||||
self.assertAllClose(expected_output, output_dataset)
|
self.assertAllClose(expected_output, output_dataset)
|
||||||
|
|
||||||
|
def test_non_unique_vocab_fails(self):
|
||||||
|
vocab_data = ["earth", "wind", "and", "fire", "fire"]
|
||||||
|
with self.assertRaisesRegex(ValueError, ".*repeated term.*fire.*"):
|
||||||
|
_ = get_layer_class()(vocabulary=vocab_data)
|
||||||
|
|
||||||
|
def test_non_unique_vocab_from_file_fails(self):
|
||||||
|
vocab_list = ["earth", "wind", "and", "fire", "earth"]
|
||||||
|
vocab_path = self._write_to_temp_file("repeat_vocab_file", vocab_list)
|
||||||
|
with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"):
|
||||||
|
_ = get_layer_class()(vocabulary=vocab_path)
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class InverseLookupOutputTest(keras_parameterized.TestCase,
|
class InverseLookupOutputTest(keras_parameterized.TestCase,
|
||||||
|
Loading…
Reference in New Issue
Block a user