Update TextVectorization to use internal layer adapt calls instead of its own combiner.
PiperOrigin-RevId: 312292748 Change-Id: Ia157a06f55a28325dac9e4a58b3fed23fc4599d4
This commit is contained in:
parent
282db86128
commit
15d39f5e83
@ -55,8 +55,9 @@ class CombinerPreprocessingLayer(
|
|||||||
|
|
||||||
def _get_dataset_iterator(self, dataset):
|
def _get_dataset_iterator(self, dataset):
|
||||||
"""Gets an iterator from a tf.data.Dataset."""
|
"""Gets an iterator from a tf.data.Dataset."""
|
||||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||||
session = K.get_session()
|
session = K.get_session()
|
||||||
|
session.run(iterator.initializer)
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
return lambda: session.run(next_element)
|
return lambda: session.run(next_element)
|
||||||
|
|
||||||
|
@ -17,10 +17,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
|
||||||
import json
|
|
||||||
import operator
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
@ -29,7 +25,6 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
|
|
||||||
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
|
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
|
||||||
from tensorflow.python.keras.layers.preprocessing import categorical_encoding
|
from tensorflow.python.keras.layers.preprocessing import categorical_encoding
|
||||||
from tensorflow.python.keras.layers.preprocessing import string_lookup
|
from tensorflow.python.keras.layers.preprocessing import string_lookup
|
||||||
@ -41,7 +36,6 @@ from tensorflow.python.ops import string_ops
|
|||||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.util import compat
|
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
|
LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
|
||||||
@ -122,7 +116,9 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
max_tokens: The maximum size of the vocabulary for this layer. If None,
|
max_tokens: The maximum size of the vocabulary for this layer. If None,
|
||||||
there is no cap on the size of the vocabulary.
|
there is no cap on the size of the vocabulary. Note that this vocabulary
|
||||||
|
contains 1 OOV token, so the effective number of tokens is `(max_tokens -
|
||||||
|
1 - (1 if output == "int" else 0))`
|
||||||
standardize: Optional specification for standardization to apply to the
|
standardize: Optional specification for standardization to apply to the
|
||||||
input text. Values can be None (no standardization),
|
input text. Values can be None (no standardization),
|
||||||
'lower_and_strip_punctuation' (lowercase and remove punctuation) or a
|
'lower_and_strip_punctuation' (lowercase and remove punctuation) or a
|
||||||
@ -138,7 +134,8 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
output_mode: Optional specification for the output of the layer. Values can
|
output_mode: Optional specification for the output of the layer. Values can
|
||||||
be "int", "binary", "count" or "tf-idf", configuring the layer as follows:
|
be "int", "binary", "count" or "tf-idf", configuring the layer as follows:
|
||||||
"int": Outputs integer indices, one integer index per split string
|
"int": Outputs integer indices, one integer index per split string
|
||||||
token.
|
token. When output == "int", 0 is reserved for masked locations;
|
||||||
|
this reduces the vocab size to max_tokens-2 instead of max_tokens-1
|
||||||
"binary": Outputs a single int array per batch, of either vocab_size or
|
"binary": Outputs a single int array per batch, of either vocab_size or
|
||||||
max_tokens size, containing 1s in all elements where the token mapped
|
max_tokens size, containing 1s in all elements where the token mapped
|
||||||
to that index exists at least once in the batch item.
|
to that index exists at least once in the batch item.
|
||||||
@ -274,12 +271,6 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
# the OOV value to zero instead of one.
|
# the OOV value to zero instead of one.
|
||||||
self._oov_value = 1 if output_mode == INT else 0
|
self._oov_value = 1 if output_mode == INT else 0
|
||||||
|
|
||||||
# We always reduce the max token number by 1 to account for the OOV token
|
|
||||||
# if it is set. Keras' use of the reserved number 0 for padding tokens,
|
|
||||||
# if the output is in INT mode, does not really count as a 'token' for
|
|
||||||
# vocabulary purposes, so we only reduce vocab size by 1 here.
|
|
||||||
self._max_vocab_size = max_tokens - 1 if max_tokens is not None else None
|
|
||||||
|
|
||||||
self._standardize = standardize
|
self._standardize = standardize
|
||||||
self._split = split
|
self._split = split
|
||||||
self._ngrams_arg = ngrams
|
self._ngrams_arg = ngrams
|
||||||
@ -295,8 +286,7 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
self._called = False
|
self._called = False
|
||||||
|
|
||||||
super(TextVectorization, self).__init__(
|
super(TextVectorization, self).__init__(
|
||||||
combiner=_TextVectorizationCombiner(
|
combiner=None,
|
||||||
self._max_vocab_size, compute_idf=output_mode == TFIDF),
|
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
mask_token = "" if output_mode in [None, INT] else None
|
mask_token = "" if output_mode in [None, INT] else None
|
||||||
@ -306,14 +296,14 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
# If this layer is configured for string or integer output, we do not
|
# If this layer is configured for string or integer output, we do not
|
||||||
# create a vectorization layer (as the output is not vectorized).
|
# create a vectorization layer (as the output is not vectorized).
|
||||||
if self._output_mode in [None, INT]:
|
if self._output_mode in [None, INT]:
|
||||||
return
|
self._vectorize_layer = None
|
||||||
|
|
||||||
if max_tokens is not None and self._pad_to_max:
|
|
||||||
max_elements = max_tokens
|
|
||||||
else:
|
else:
|
||||||
max_elements = None
|
if max_tokens is not None and self._pad_to_max:
|
||||||
self._vectorize_layer = self._get_vectorization_class()(
|
max_elements = max_tokens
|
||||||
max_tokens=max_elements, output_mode=self._output_mode)
|
else:
|
||||||
|
max_elements = None
|
||||||
|
self._vectorize_layer = self._get_vectorization_class()(
|
||||||
|
max_tokens=max_elements, output_mode=self._output_mode)
|
||||||
|
|
||||||
# These are V1/V2 shim points. There are V1 implementations in the V1 class.
|
# These are V1/V2 shim points. There are V1 implementations in the V1 class.
|
||||||
def _get_vectorization_class(self):
|
def _get_vectorization_class(self):
|
||||||
@ -407,7 +397,14 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"adapt() requires a Dataset or an array as input, got {}".format(
|
"adapt() requires a Dataset or an array as input, got {}".format(
|
||||||
type(data)))
|
type(data)))
|
||||||
super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
|
|
||||||
|
self._index_lookup_layer.adapt(preprocessed_inputs)
|
||||||
|
if self._vectorize_layer:
|
||||||
|
if isinstance(data, ops.Tensor):
|
||||||
|
integer_data = self._index_lookup_layer(preprocessed_inputs)
|
||||||
|
else:
|
||||||
|
integer_data = preprocessed_inputs.map(self._index_lookup_layer)
|
||||||
|
self._vectorize_layer.adapt(integer_data)
|
||||||
|
|
||||||
def get_vocabulary(self):
|
def get_vocabulary(self):
|
||||||
return self._index_lookup_layer.get_vocabulary()
|
return self._index_lookup_layer.get_vocabulary()
|
||||||
@ -616,191 +613,3 @@ class TextVectorization(CombinerPreprocessingLayer):
|
|||||||
# If we're not returning integers here, we rely on the vectorization layer
|
# If we're not returning integers here, we rely on the vectorization layer
|
||||||
# to create the output.
|
# to create the output.
|
||||||
return self._vectorize_layer(indexed_data)
|
return self._vectorize_layer(indexed_data)
|
||||||
|
|
||||||
|
|
||||||
class _TextVectorizationAccumulator(
|
|
||||||
collections.namedtuple("_TextVectorizationAccumulator",
|
|
||||||
["count_dict", "per_doc_count_dict", "metadata"])):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# A note on this combiner: This contains functionality that will be extracted
|
|
||||||
# into the Vectorization and IndexLookup combiner objects. At that point,
|
|
||||||
# TextVectorization can become a PreprocessingStage instead of a Layer and
|
|
||||||
# this combiner can be retired. Until then, we leave this as is instead of
|
|
||||||
# attempting a refactor of what will soon be deleted.
|
|
||||||
class _TextVectorizationCombiner(Combiner):
|
|
||||||
"""Combiner for the TextVectorization preprocessing layer.
|
|
||||||
|
|
||||||
This class encapsulates the logic for computing a vocabulary based on the
|
|
||||||
frequency of each token.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
vocab_size: (Optional) If set, only the top `vocab_size` tokens (based on
|
|
||||||
frequency across the dataset) are retained in the vocabulary. If None, or
|
|
||||||
set to a value greater than the total number of distinct tokens in the
|
|
||||||
dataset, all tokens are retained.
|
|
||||||
compute_idf: (Optional) If set, the inverse document frequency will be
|
|
||||||
computed for each value.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, vocab_size=None, compute_idf=False):
|
|
||||||
self._vocab_size = vocab_size
|
|
||||||
self._compute_idf = compute_idf
|
|
||||||
self._input_dtype = dtypes.string
|
|
||||||
|
|
||||||
def compute(self, values, accumulator=None):
|
|
||||||
"""Compute a step in this computation, returning a new accumulator."""
|
|
||||||
if dtypes.as_dtype(self._input_dtype) != dtypes.as_dtype(values.dtype):
|
|
||||||
raise RuntimeError("Expected input type %s, got %s" %
|
|
||||||
(self._input_dtype, values.dtype))
|
|
||||||
if ragged_tensor.is_ragged(values):
|
|
||||||
values = values.to_list()
|
|
||||||
if isinstance(values, ops.EagerTensor):
|
|
||||||
values = values.numpy()
|
|
||||||
if isinstance(values, np.ndarray):
|
|
||||||
values = values.tolist()
|
|
||||||
|
|
||||||
if accumulator is None:
|
|
||||||
accumulator = self._create_accumulator()
|
|
||||||
|
|
||||||
# If we are being passed raw strings or bytestrings, we need to wrap them
|
|
||||||
# in an array so we don't accidentally iterate over the bytes instead of
|
|
||||||
# treating the string as one object.
|
|
||||||
if isinstance(values, (str, bytes)):
|
|
||||||
values = [values]
|
|
||||||
|
|
||||||
# TODO(momernick): Benchmark improvements to this algorithm.
|
|
||||||
for document in values:
|
|
||||||
current_doc_id = accumulator.metadata[0]
|
|
||||||
for token in document:
|
|
||||||
accumulator.count_dict[token] += 1
|
|
||||||
if self._compute_idf:
|
|
||||||
doc_count = accumulator.per_doc_count_dict[token]
|
|
||||||
if doc_count["last_doc_id"] != current_doc_id:
|
|
||||||
doc_count["count"] += 1
|
|
||||||
doc_count["last_doc_id"] = current_doc_id
|
|
||||||
accumulator.metadata[0] += 1
|
|
||||||
|
|
||||||
return accumulator
|
|
||||||
|
|
||||||
def merge(self, accumulators):
|
|
||||||
"""Merge several accumulators to a single accumulator."""
|
|
||||||
if not accumulators:
|
|
||||||
return accumulators
|
|
||||||
|
|
||||||
base_accumulator = accumulators[0]
|
|
||||||
|
|
||||||
for accumulator in accumulators[1:]:
|
|
||||||
base_accumulator.metadata[0] += accumulator.metadata[0]
|
|
||||||
for token, value in accumulator.count_dict.items():
|
|
||||||
base_accumulator.count_dict[token] += value
|
|
||||||
if self._compute_idf:
|
|
||||||
for token, value in accumulator.per_doc_count_dict.items():
|
|
||||||
# Any newly created token counts in 'base_accumulator''s
|
|
||||||
# per_doc_count_dict will have a last_doc_id of -1. This is always
|
|
||||||
# less than the next doc id (which are strictly positive), so any
|
|
||||||
# future occurrences are guaranteed to be counted.
|
|
||||||
base_accumulator.per_doc_count_dict[token]["count"] += value["count"]
|
|
||||||
|
|
||||||
return base_accumulator
|
|
||||||
|
|
||||||
def _inverse_document_frequency(self, document_counts, num_documents):
|
|
||||||
"""Compute the inverse-document-frequency (IDF) component of TFIDF.
|
|
||||||
|
|
||||||
Uses the default weighting scheme described in
|
|
||||||
https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
document_counts: An array of the # of documents each token appears in.
|
|
||||||
num_documents: An int representing the total number of documents
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An array of "inverse document frequency" weights.
|
|
||||||
"""
|
|
||||||
return np.log(1 + num_documents / (1 + np.array(document_counts)))
|
|
||||||
|
|
||||||
def extract(self, accumulator):
|
|
||||||
"""Convert an accumulator into a dict of output values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
accumulator: An accumulator aggregating over the full dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict of:
|
|
||||||
"vocab": A list of the retained items in the vocabulary.
|
|
||||||
"idf": The inverse-document-frequency for each item in vocab.
|
|
||||||
idf[vocab_idx] is the IDF value for the corresponding vocab item.
|
|
||||||
"oov_idf": The inverse-document-frequency for the OOV token.
|
|
||||||
"""
|
|
||||||
if self._compute_idf:
|
|
||||||
vocab_counts, document_counts, num_documents = accumulator
|
|
||||||
else:
|
|
||||||
vocab_counts, _, _ = accumulator
|
|
||||||
|
|
||||||
sorted_counts = sorted(
|
|
||||||
vocab_counts.items(), key=operator.itemgetter(1, 0), reverse=True)
|
|
||||||
vocab_data = (
|
|
||||||
sorted_counts[:self._vocab_size] if self._vocab_size else sorted_counts)
|
|
||||||
vocab = [data[0] for data in vocab_data]
|
|
||||||
|
|
||||||
if self._compute_idf:
|
|
||||||
doc_counts = [document_counts[token]["count"] for token in vocab]
|
|
||||||
idf = self._inverse_document_frequency(doc_counts, num_documents[0])
|
|
||||||
oov_idf = np.array([np.log(1 + num_documents[0])])
|
|
||||||
return {_VOCAB_NAME: vocab, _IDF_NAME: idf, _OOV_IDF_NAME: oov_idf}
|
|
||||||
else:
|
|
||||||
return {_VOCAB_NAME: vocab}
|
|
||||||
|
|
||||||
def restore(self, output):
|
|
||||||
"""Create an accumulator based on 'output'."""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"TextVectorization does not restore or support streaming updates.")
|
|
||||||
|
|
||||||
def serialize(self, accumulator):
|
|
||||||
"""Serialize an accumulator for a remote call."""
|
|
||||||
output_dict = {}
|
|
||||||
output_dict["metadata"] = accumulator.metadata
|
|
||||||
output_dict["vocab"] = list(accumulator.count_dict.keys())
|
|
||||||
output_dict["vocab_counts"] = list(accumulator.count_dict.values())
|
|
||||||
if self._compute_idf:
|
|
||||||
output_dict["idf_vocab"] = list(accumulator.per_doc_count_dict.keys())
|
|
||||||
output_dict["idf_counts"] = [
|
|
||||||
counter["count"]
|
|
||||||
for counter in accumulator.per_doc_count_dict.values()
|
|
||||||
]
|
|
||||||
return compat.as_bytes(json.dumps(output_dict))
|
|
||||||
|
|
||||||
def deserialize(self, encoded_accumulator):
|
|
||||||
"""Deserialize an accumulator received from 'serialize()'."""
|
|
||||||
accumulator_dict = json.loads(compat.as_text(encoded_accumulator))
|
|
||||||
|
|
||||||
accumulator = self._create_accumulator()
|
|
||||||
accumulator.metadata[0] = accumulator_dict["metadata"][0]
|
|
||||||
|
|
||||||
count_dict = dict(
|
|
||||||
zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"]))
|
|
||||||
accumulator.count_dict.update(count_dict)
|
|
||||||
|
|
||||||
if self._compute_idf:
|
|
||||||
create_dict = lambda x: {"count": x, "last_doc_id": -1}
|
|
||||||
idf_count_dicts = [
|
|
||||||
create_dict(count) for count in accumulator_dict["idf_counts"]
|
|
||||||
]
|
|
||||||
idf_dict = dict(zip(accumulator_dict["idf_vocab"], idf_count_dicts))
|
|
||||||
accumulator.per_doc_count_dict.update(idf_dict)
|
|
||||||
|
|
||||||
return accumulator
|
|
||||||
|
|
||||||
def _create_accumulator(self):
|
|
||||||
"""Accumulate a sorted array of vocab tokens and corresponding counts."""
|
|
||||||
|
|
||||||
count_dict = collections.defaultdict(int)
|
|
||||||
if self._compute_idf:
|
|
||||||
create_default_dict = lambda: {"count": 0, "last_doc_id": -1}
|
|
||||||
per_doc_count_dict = collections.defaultdict(create_default_dict)
|
|
||||||
else:
|
|
||||||
per_doc_count_dict = None
|
|
||||||
metadata = [0]
|
|
||||||
return _TextVectorizationAccumulator(count_dict, per_doc_count_dict,
|
|
||||||
metadata)
|
|
||||||
|
@ -62,7 +62,7 @@ def _get_end_to_end_test_cases():
|
|||||||
"testcase_name":
|
"testcase_name":
|
||||||
"test_simple_tokens_int_mode",
|
"test_simple_tokens_int_mode",
|
||||||
# Create an array where 'earth' is the most frequent term, followed by
|
# Create an array where 'earth' is the most frequent term, followed by
|
||||||
# 'wind', then 'and', then 'fire'. This ensures that the vocab accumulator
|
# 'wind', then 'and', then 'fire'. This ensures that the vocab
|
||||||
# is sorting by frequency.
|
# is sorting by frequency.
|
||||||
"vocab_data":
|
"vocab_data":
|
||||||
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
|
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
|
||||||
@ -78,6 +78,26 @@ def _get_end_to_end_test_cases():
|
|||||||
},
|
},
|
||||||
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
|
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"testcase_name":
|
||||||
|
"test_simple_tokens_int_mode_hard_cap",
|
||||||
|
# Create an array where 'earth' is the most frequent term, followed by
|
||||||
|
# 'wind', then 'and', then 'fire'. This ensures that the vocab
|
||||||
|
# is sorting by frequency.
|
||||||
|
"vocab_data":
|
||||||
|
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
|
||||||
|
["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
|
||||||
|
"input_data":
|
||||||
|
np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
|
||||||
|
["and"], ["earth"], ["michigan"]]),
|
||||||
|
"kwargs": {
|
||||||
|
"max_tokens": 6,
|
||||||
|
"standardize": None,
|
||||||
|
"split": None,
|
||||||
|
"output_mode": text_vectorization.INT
|
||||||
|
},
|
||||||
|
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"testcase_name":
|
"testcase_name":
|
||||||
"test_documents_int_mode",
|
"test_documents_int_mode",
|
||||||
@ -985,7 +1005,7 @@ class TextVectorizationOutputTest(
|
|||||||
output_mode=text_vectorization.BINARY,
|
output_mode=text_vectorization.BINARY,
|
||||||
pad_to_max_tokens=False)
|
pad_to_max_tokens=False)
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
|
||||||
layer.adapt(vocab_data)
|
layer.adapt(vocab_data)
|
||||||
|
|
||||||
def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self):
|
def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self):
|
||||||
@ -1347,6 +1367,7 @@ class TextVectorizationErrorTest(keras_parameterized.TestCase,
|
|||||||
".*`output_sequence_length` must not be set.*"):
|
".*`output_sequence_length` must not be set.*"):
|
||||||
_ = get_layer_class()(output_mode="count", output_sequence_length=2)
|
_ = get_layer_class()(output_mode="count", output_sequence_length=2)
|
||||||
|
|
||||||
|
|
||||||
# Custom functions for the custom callable serialization test. Declared here
|
# Custom functions for the custom callable serialization test. Declared here
|
||||||
# to avoid multiple registrations from run_all_keras_modes().
|
# to avoid multiple registrations from run_all_keras_modes().
|
||||||
@generic_utils.register_keras_serializable(package="Test")
|
@generic_utils.register_keras_serializable(package="Test")
|
||||||
@ -1528,208 +1549,5 @@ class TextVectorizationSavingTest(
|
|||||||
self.assertAllClose(new_output_dataset, expected_output)
|
self.assertAllClose(new_output_dataset, expected_output)
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
|
||||||
class TextVectorizationCombinerTest(
|
|
||||||
keras_parameterized.TestCase,
|
|
||||||
preprocessing_test_utils.PreprocessingLayerTest):
|
|
||||||
|
|
||||||
def compare_text_accumulators(self, a, b, msg=None):
|
|
||||||
if a is None or b is None:
|
|
||||||
self.assertAllEqual(a, b, msg=msg)
|
|
||||||
|
|
||||||
self.assertAllEqual(a.count_dict, b.count_dict, msg=msg)
|
|
||||||
self.assertAllEqual(a.metadata, b.metadata, msg=msg)
|
|
||||||
|
|
||||||
if a.per_doc_count_dict is not None:
|
|
||||||
|
|
||||||
def per_doc_counts(accumulator):
|
|
||||||
count_values = [
|
|
||||||
count_dict["count"]
|
|
||||||
for count_dict in accumulator.per_doc_count_dict.values()
|
|
||||||
]
|
|
||||||
return dict(zip(accumulator.per_doc_count_dict.keys(), count_values))
|
|
||||||
|
|
||||||
self.assertAllEqual(per_doc_counts(a), per_doc_counts(b), msg=msg)
|
|
||||||
|
|
||||||
compare_accumulators = compare_text_accumulators
|
|
||||||
|
|
||||||
def update_accumulator(self, accumulator, data):
|
|
||||||
accumulator.count_dict.update(dict(zip(data["vocab"], data["counts"])))
|
|
||||||
accumulator.metadata[0] = data["num_documents"]
|
|
||||||
|
|
||||||
if "document_counts" in data:
|
|
||||||
create_dict = lambda x: {"count": x, "last_doc_id": -1}
|
|
||||||
idf_count_dicts = [
|
|
||||||
create_dict(count) for count in data["document_counts"]
|
|
||||||
]
|
|
||||||
idf_dict = dict(zip(data["vocab"], idf_count_dicts))
|
|
||||||
|
|
||||||
accumulator.per_doc_count_dict.update(idf_dict)
|
|
||||||
|
|
||||||
return accumulator
|
|
||||||
|
|
||||||
def test_combiner_api_compatibility_int_mode(self):
|
|
||||||
data = np.array([["earth", "wind", "and", "fire"],
|
|
||||||
["earth", "wind", "and", "michigan"]])
|
|
||||||
combiner = text_vectorization._TextVectorizationCombiner(compute_idf=False)
|
|
||||||
expected_accumulator_output = {
|
|
||||||
"vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
|
|
||||||
"counts": np.array([2, 2, 2, 1, 1]),
|
|
||||||
"num_documents": np.array(2),
|
|
||||||
}
|
|
||||||
expected_extract_output = {
|
|
||||||
"vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
|
|
||||||
}
|
|
||||||
expected_accumulator = combiner._create_accumulator()
|
|
||||||
expected_accumulator = self.update_accumulator(expected_accumulator,
|
|
||||||
expected_accumulator_output)
|
|
||||||
self.validate_accumulator_serialize_and_deserialize(combiner, data,
|
|
||||||
expected_accumulator)
|
|
||||||
self.validate_accumulator_uniqueness(combiner, data)
|
|
||||||
self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
|
||||||
|
|
||||||
def test_combiner_api_compatibility_tfidf_mode(self):
|
|
||||||
data = np.array([["earth", "wind", "and", "fire"],
|
|
||||||
["earth", "wind", "and", "michigan"]])
|
|
||||||
combiner = text_vectorization._TextVectorizationCombiner(compute_idf=True)
|
|
||||||
expected_extract_output = {
|
|
||||||
"vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
|
|
||||||
"idf": np.array([0.510826, 0.510826, 0.510826, 0.693147, 0.693147]),
|
|
||||||
"oov_idf": np.array([1.098612])
|
|
||||||
}
|
|
||||||
expected_accumulator_output = {
|
|
||||||
"vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
|
|
||||||
"counts": np.array([2, 2, 2, 1, 1]),
|
|
||||||
"document_counts": np.array([2, 2, 2, 1, 1]),
|
|
||||||
"num_documents": np.array(2),
|
|
||||||
}
|
|
||||||
|
|
||||||
expected_accumulator = combiner._create_accumulator()
|
|
||||||
expected_accumulator = self.update_accumulator(expected_accumulator,
|
|
||||||
expected_accumulator_output)
|
|
||||||
self.validate_accumulator_serialize_and_deserialize(combiner, data,
|
|
||||||
expected_accumulator)
|
|
||||||
self.validate_accumulator_uniqueness(combiner, data)
|
|
||||||
self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
|
||||||
|
|
||||||
# TODO(askerryryan): Add tests confirming equivalence to behavior of
|
|
||||||
# existing tf.keras.preprocessing.text.Tokenizer.
|
|
||||||
@parameterized.named_parameters(
|
|
||||||
{
|
|
||||||
"testcase_name":
|
|
||||||
"top_k_smaller_than_full_vocab",
|
|
||||||
"data":
|
|
||||||
np.array([["earth", "wind"], ["fire", "wind"], ["and"],
|
|
||||||
["fire", "wind"]]),
|
|
||||||
"vocab_size":
|
|
||||||
3,
|
|
||||||
"expected_accumulator_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth", "and"]),
|
|
||||||
"counts": np.array([3, 2, 1, 1]),
|
|
||||||
"document_counts": np.array([3, 2, 1, 1]),
|
|
||||||
"num_documents": np.array(4),
|
|
||||||
},
|
|
||||||
"expected_extract_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth"]),
|
|
||||||
"idf": np.array([0.693147, 0.847298, 1.098612]),
|
|
||||||
"oov_idf": np.array([1.609438]),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"testcase_name":
|
|
||||||
"top_k_larger_than_full_vocab",
|
|
||||||
"data":
|
|
||||||
np.array([["earth", "wind"], ["fire", "wind"], ["and"],
|
|
||||||
["fire", "wind"]]),
|
|
||||||
"vocab_size":
|
|
||||||
10,
|
|
||||||
"expected_accumulator_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth", "and"]),
|
|
||||||
"counts": np.array([3, 2, 1, 1]),
|
|
||||||
"document_counts": np.array([3, 2, 1, 1]),
|
|
||||||
"num_documents": np.array(4),
|
|
||||||
},
|
|
||||||
"expected_extract_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth", "and"]),
|
|
||||||
"idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]),
|
|
||||||
"oov_idf": np.array([1.609438]),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"testcase_name":
|
|
||||||
"no_top_k",
|
|
||||||
"data":
|
|
||||||
np.array([["earth", "wind"], ["fire", "wind"], ["and"],
|
|
||||||
["fire", "wind"]]),
|
|
||||||
"vocab_size":
|
|
||||||
None,
|
|
||||||
"expected_accumulator_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth", "and"]),
|
|
||||||
"counts": np.array([3, 2, 1, 1]),
|
|
||||||
"document_counts": np.array([3, 2, 1, 1]),
|
|
||||||
"num_documents": np.array(4),
|
|
||||||
},
|
|
||||||
"expected_extract_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth", "and"]),
|
|
||||||
"idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]),
|
|
||||||
"oov_idf": np.array([1.609438]),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"testcase_name": "single_element_per_row",
|
|
||||||
"data": np.array([["earth"], ["wind"], ["fire"], ["wind"], ["and"]]),
|
|
||||||
"vocab_size": 3,
|
|
||||||
"expected_accumulator_output": {
|
|
||||||
"vocab": np.array(["wind", "and", "earth", "fire"]),
|
|
||||||
"counts": np.array([2, 1, 1, 1]),
|
|
||||||
"document_counts": np.array([2, 1, 1, 1]),
|
|
||||||
"num_documents": np.array(5),
|
|
||||||
},
|
|
||||||
"expected_extract_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth"]),
|
|
||||||
"idf": np.array([0.980829, 1.252763, 1.252763]),
|
|
||||||
"oov_idf": np.array([1.791759]),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
# Which tokens are retained are based on global frequency, and thus are
|
|
||||||
# sensitive to frequency within a document. In contrast, because idf only
|
|
||||||
# considers the presence of a token in a document, it is insensitive
|
|
||||||
# to the frequency of the token within the document.
|
|
||||||
{
|
|
||||||
"testcase_name":
|
|
||||||
"retained_tokens_sensitive_to_within_document_frequency",
|
|
||||||
"data":
|
|
||||||
np.array([["earth", "earth"], ["wind", "wind"], ["fire", "fire"],
|
|
||||||
["wind", "wind"], ["and", "michigan"]]),
|
|
||||||
"vocab_size":
|
|
||||||
3,
|
|
||||||
"expected_accumulator_output": {
|
|
||||||
"vocab": np.array(["wind", "earth", "fire", "and", "michigan"]),
|
|
||||||
"counts": np.array([4, 2, 2, 1, 1]),
|
|
||||||
"document_counts": np.array([2, 1, 1, 1, 1]),
|
|
||||||
"num_documents": np.array(5),
|
|
||||||
},
|
|
||||||
"expected_extract_output": {
|
|
||||||
"vocab": np.array(["wind", "fire", "earth"]),
|
|
||||||
"idf": np.array([0.980829, 1.252763, 1.252763]),
|
|
||||||
"oov_idf": np.array([1.791759]),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
def test_combiner_computation(self,
|
|
||||||
data,
|
|
||||||
vocab_size,
|
|
||||||
expected_accumulator_output,
|
|
||||||
expected_extract_output,
|
|
||||||
compute_idf=True):
|
|
||||||
combiner = text_vectorization._TextVectorizationCombiner(
|
|
||||||
vocab_size=vocab_size, compute_idf=compute_idf)
|
|
||||||
expected_accumulator = combiner._create_accumulator()
|
|
||||||
expected_accumulator = self.update_accumulator(expected_accumulator,
|
|
||||||
expected_accumulator_output)
|
|
||||||
self.validate_accumulator_computation(combiner, data, expected_accumulator)
|
|
||||||
self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user