Update TextVectorization to use internal layer adapt calls instead of its own combiner.
PiperOrigin-RevId: 312292748 Change-Id: Ia157a06f55a28325dac9e4a58b3fed23fc4599d4
This commit is contained in:
parent
282db86128
commit
15d39f5e83
@ -55,8 +55,9 @@ class CombinerPreprocessingLayer(
|
||||
|
||||
def _get_dataset_iterator(self, dataset):
  """Gets an iterator from a tf.data.Dataset.

  Args:
    dataset: The dataset to iterate over.

  Returns:
    A zero-argument callable. Each call runs the iterator's `get_next()` op
    in the Keras backend session and returns the result.
  """
  # Only the initializable iterator is built: its `initializer` op is run
  # explicitly below, before any element is requested.
  iterator = dataset_ops.make_initializable_iterator(dataset)
  session = K.get_session()
  session.run(iterator.initializer)
  # Create the get_next op once; the returned callable re-runs the same op.
  next_element = iterator.get_next()
  return lambda: session.run(next_element)
|
||||
|
||||
|
@ -17,10 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -29,7 +25,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
|
||||
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
|
||||
from tensorflow.python.keras.layers.preprocessing import categorical_encoding
|
||||
from tensorflow.python.keras.layers.preprocessing import string_lookup
|
||||
@ -41,7 +36,6 @@ from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
|
||||
@ -122,7 +116,9 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
|
||||
Attributes:
|
||||
max_tokens: The maximum size of the vocabulary for this layer. If None,
|
||||
there is no cap on the size of the vocabulary.
|
||||
there is no cap on the size of the vocabulary. Note that this vocabulary
|
||||
contains 1 OOV token, so the effective number of tokens is `(max_tokens -
|
||||
1 - (1 if output == "int" else 0))`
|
||||
standardize: Optional specification for standardization to apply to the
|
||||
input text. Values can be None (no standardization),
|
||||
'lower_and_strip_punctuation' (lowercase and remove punctuation) or a
|
||||
@ -138,7 +134,8 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
output_mode: Optional specification for the output of the layer. Values can
|
||||
be "int", "binary", "count" or "tf-idf", configuring the layer as follows:
|
||||
"int": Outputs integer indices, one integer index per split string
|
||||
token.
|
||||
token. When output == "int", 0 is reserved for masked locations;
|
||||
this reduces the vocab size to max_tokens-2 instead of max_tokens-1
|
||||
"binary": Outputs a single int array per batch, of either vocab_size or
|
||||
max_tokens size, containing 1s in all elements where the token mapped
|
||||
to that index exists at least once in the batch item.
|
||||
@ -274,12 +271,6 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
# the OOV value to zero instead of one.
|
||||
self._oov_value = 1 if output_mode == INT else 0
|
||||
|
||||
# We always reduce the max token number by 1 to account for the OOV token
|
||||
# if it is set. Keras' use of the reserved number 0 for padding tokens,
|
||||
# if the output is in INT mode, does not really count as a 'token' for
|
||||
# vocabulary purposes, so we only reduce vocab size by 1 here.
|
||||
self._max_vocab_size = max_tokens - 1 if max_tokens is not None else None
|
||||
|
||||
self._standardize = standardize
|
||||
self._split = split
|
||||
self._ngrams_arg = ngrams
|
||||
@ -295,8 +286,7 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
self._called = False
|
||||
|
||||
super(TextVectorization, self).__init__(
|
||||
combiner=_TextVectorizationCombiner(
|
||||
self._max_vocab_size, compute_idf=output_mode == TFIDF),
|
||||
combiner=None,
|
||||
**kwargs)
|
||||
|
||||
mask_token = "" if output_mode in [None, INT] else None
|
||||
@ -306,8 +296,8 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
# If this layer is configured for string or integer output, we do not
|
||||
# create a vectorization layer (as the output is not vectorized).
|
||||
if self._output_mode in [None, INT]:
|
||||
return
|
||||
|
||||
self._vectorize_layer = None
|
||||
else:
|
||||
if max_tokens is not None and self._pad_to_max:
|
||||
max_elements = max_tokens
|
||||
else:
|
||||
@ -407,7 +397,14 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
raise ValueError(
|
||||
"adapt() requires a Dataset or an array as input, got {}".format(
|
||||
type(data)))
|
||||
super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
|
||||
|
||||
self._index_lookup_layer.adapt(preprocessed_inputs)
|
||||
if self._vectorize_layer:
|
||||
if isinstance(data, ops.Tensor):
|
||||
integer_data = self._index_lookup_layer(preprocessed_inputs)
|
||||
else:
|
||||
integer_data = preprocessed_inputs.map(self._index_lookup_layer)
|
||||
self._vectorize_layer.adapt(integer_data)
|
||||
|
||||
def get_vocabulary(self):
  """Returns the vocabulary reported by the internal index lookup layer."""
  lookup_layer = self._index_lookup_layer
  return lookup_layer.get_vocabulary()
|
||||
@ -616,191 +613,3 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
# If we're not returning integers here, we rely on the vectorization layer
|
||||
# to create the output.
|
||||
return self._vectorize_layer(indexed_data)
|
||||
|
||||
|
||||
class _TextVectorizationAccumulator(
    collections.namedtuple(
        "_TextVectorizationAccumulator",
        ("count_dict", "per_doc_count_dict", "metadata"))):
  # Immutable 3-field record used as the combiner's accumulator:
  #   count_dict: token -> total occurrence count.
  #   per_doc_count_dict: token -> {"count", "last_doc_id"} dict, or None when
  #     IDF is not being computed.
  #   metadata: single-element list holding the number of documents seen.
  pass
|
||||
|
||||
|
||||
# A note on this combiner: This contains functionality that will be extracted
|
||||
# into the Vectorization and IndexLookup combiner objects. At that point,
|
||||
# TextVectorization can become a PreprocessingStage instead of a Layer and
|
||||
# this combiner can be retired. Until then, we leave this as is instead of
|
||||
# attempting a refactor of what will soon be deleted.
|
||||
class _TextVectorizationCombiner(Combiner):
  """Combiner for the TextVectorization preprocessing layer.

  This class encapsulates the logic for computing a vocabulary based on the
  frequency of each token.

  Attributes:
    vocab_size: (Optional) If set, only the top `vocab_size` tokens (based on
      frequency across the dataset) are retained in the vocabulary. If None, or
      set to a value greater than the total number of distinct tokens in the
      dataset, all tokens are retained.
    compute_idf: (Optional) If set, the inverse document frequency will be
      computed for each value.
  """

  def __init__(self, vocab_size=None, compute_idf=False):
    self._vocab_size = vocab_size
    self._compute_idf = compute_idf
    self._input_dtype = dtypes.string

  def compute(self, values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    Args:
      values: A batch of documents with string dtype. Ragged tensors, eager
        tensors and numpy arrays are converted to (nested) Python lists; each
        element of the outer sequence is treated as one document and each
        element of a document as one token.
      accumulator: (Optional) An accumulator to update in place. A fresh one
        is created when None.

    Returns:
      The updated accumulator.

    Raises:
      RuntimeError: If `values.dtype` is not the expected string dtype.
    """
    if dtypes.as_dtype(self._input_dtype) != dtypes.as_dtype(values.dtype):
      raise RuntimeError("Expected input type %s, got %s" %
                         (self._input_dtype, values.dtype))
    if ragged_tensor.is_ragged(values):
      values = values.to_list()
    if isinstance(values, ops.EagerTensor):
      values = values.numpy()
    if isinstance(values, np.ndarray):
      values = values.tolist()

    if accumulator is None:
      accumulator = self._create_accumulator()

    # A scalar input collapses to a bare string/bytestring after the
    # conversions above. Wrap it as a single document holding a single token;
    # wrapping it only once would make the token loop below iterate over the
    # individual characters/bytes instead of treating the string as one
    # object.
    if isinstance(values, (str, bytes)):
      values = [[values]]

    # TODO(momernick): Benchmark improvements to this algorithm.
    for document in values:
      # metadata[0] is the number of documents seen so far, which doubles as
      # the id of the document currently being processed.
      current_doc_id = accumulator.metadata[0]
      for token in document:
        accumulator.count_dict[token] += 1
        if self._compute_idf:
          doc_count = accumulator.per_doc_count_dict[token]
          # Count each token at most once per document.
          if doc_count["last_doc_id"] != current_doc_id:
            doc_count["count"] += 1
            doc_count["last_doc_id"] = current_doc_id
      accumulator.metadata[0] += 1

    return accumulator

  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator.

    Args:
      accumulators: A sequence of accumulators. The first one is updated in
        place and returned.

    Returns:
      The merged accumulator, or `accumulators` itself when it is empty.
    """
    if not accumulators:
      return accumulators

    base_accumulator = accumulators[0]

    for accumulator in accumulators[1:]:
      base_accumulator.metadata[0] += accumulator.metadata[0]
      for token, value in accumulator.count_dict.items():
        base_accumulator.count_dict[token] += value

      if self._compute_idf:
        for token, value in accumulator.per_doc_count_dict.items():
          # Any newly created token counts in 'base_accumulator''s
          # per_doc_count_dict will have a last_doc_id of -1. This is always
          # less than the next doc id (which are non-negative), so any
          # future occurrences are guaranteed to be counted.
          base_accumulator.per_doc_count_dict[token]["count"] += value["count"]

    return base_accumulator

  def _inverse_document_frequency(self, document_counts, num_documents):
    """Compute the inverse-document-frequency (IDF) component of TFIDF.

    Uses the default weighting scheme described in
    https://en.wikipedia.org/wiki/Tf%E2%80%93idf.

    Args:
      document_counts: An array of the # of documents each token appears in.
      num_documents: An int representing the total number of documents

    Returns:
      An array of "inverse document frequency" weights.
    """
    return np.log(1 + num_documents / (1 + np.array(document_counts)))

  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: An accumulator aggregating over the full dataset.

    Returns:
      A dict of:
        "vocab": A list of the retained items in the vocabulary.
        "idf": The inverse-document-frequency for each item in vocab.
          idf[vocab_idx] is the IDF value for the corresponding vocab item.
        "oov_idf": The inverse-document-frequency for the OOV token.
      The "idf" and "oov_idf" entries are only present when `compute_idf` is
      set.
    """
    vocab_counts = accumulator.count_dict

    # Sort by (count, token) in descending order so that ties between equally
    # frequent tokens are broken deterministically.
    sorted_counts = sorted(
        vocab_counts.items(), key=operator.itemgetter(1, 0), reverse=True)
    # A falsy vocab_size (None or 0) means "keep every token".
    vocab_data = (
        sorted_counts[:self._vocab_size] if self._vocab_size else sorted_counts)
    vocab = [data[0] for data in vocab_data]

    if not self._compute_idf:
      return {_VOCAB_NAME: vocab}

    document_counts = accumulator.per_doc_count_dict
    num_documents = accumulator.metadata[0]
    doc_counts = [document_counts[token]["count"] for token in vocab]
    idf = self._inverse_document_frequency(doc_counts, num_documents)
    oov_idf = np.array([np.log(1 + num_documents)])
    return {_VOCAB_NAME: vocab, _IDF_NAME: idf, _OOV_IDF_NAME: oov_idf}

  def restore(self, output):
    """Create an accumulator based on 'output'."""
    raise NotImplementedError(
        "TextVectorization does not restore or support streaming updates.")

  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A bytestring holding a JSON encoding of the accumulator's metadata,
      tokens and counts (plus per-document counts when `compute_idf` is set).
    """
    output_dict = {}
    output_dict["metadata"] = accumulator.metadata
    output_dict["vocab"] = list(accumulator.count_dict.keys())
    output_dict["vocab_counts"] = list(accumulator.count_dict.values())
    if self._compute_idf:
      output_dict["idf_vocab"] = list(accumulator.per_doc_count_dict.keys())
      output_dict["idf_counts"] = [
          counter["count"]
          for counter in accumulator.per_doc_count_dict.values()
      ]
    return compat.as_bytes(json.dumps(output_dict))

  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    Args:
      encoded_accumulator: The bytestring produced by `serialize()`.

    Returns:
      A new accumulator populated from the encoded data.
    """
    accumulator_dict = json.loads(compat.as_text(encoded_accumulator))

    accumulator = self._create_accumulator()
    accumulator.metadata[0] = accumulator_dict["metadata"][0]

    count_dict = dict(
        zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"]))
    accumulator.count_dict.update(count_dict)

    if self._compute_idf:
      # last_doc_id is not serialized; -1 marks "not seen in any document of
      # the current stream yet".
      idf_dict = {
          token: {"count": count, "last_doc_id": -1}
          for token, count in zip(accumulator_dict["idf_vocab"],
                                  accumulator_dict["idf_counts"])
      }
      accumulator.per_doc_count_dict.update(idf_dict)

    return accumulator

  def _create_accumulator(self):
    """Create an empty accumulator.

    Returns:
      A `_TextVectorizationAccumulator` with an empty token count dict, a
      per-document count dict (None unless `compute_idf` is set) and a
      document counter initialised to zero.
    """
    count_dict = collections.defaultdict(int)
    if self._compute_idf:
      per_doc_count_dict = collections.defaultdict(
          lambda: {"count": 0, "last_doc_id": -1})
    else:
      per_doc_count_dict = None
    metadata = [0]
    return _TextVectorizationAccumulator(count_dict, per_doc_count_dict,
                                         metadata)
|
||||
|
@ -62,7 +62,7 @@ def _get_end_to_end_test_cases():
|
||||
"testcase_name":
|
||||
"test_simple_tokens_int_mode",
|
||||
# Create an array where 'earth' is the most frequent term, followed by
|
||||
# 'wind', then 'and', then 'fire'. This ensures that the vocab accumulator
|
||||
# 'wind', then 'and', then 'fire'. This ensures that the vocab
|
||||
# is sorting by frequency.
|
||||
"vocab_data":
|
||||
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
|
||||
@ -78,6 +78,26 @@ def _get_end_to_end_test_cases():
|
||||
},
|
||||
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
|
||||
},
|
||||
{
|
||||
"testcase_name":
|
||||
"test_simple_tokens_int_mode_hard_cap",
|
||||
# Create an array where 'earth' is the most frequent term, followed by
|
||||
# 'wind', then 'and', then 'fire'. This ensures that the vocab
|
||||
# is sorting by frequency.
|
||||
"vocab_data":
|
||||
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
|
||||
["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
|
||||
"input_data":
|
||||
np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
|
||||
["and"], ["earth"], ["michigan"]]),
|
||||
"kwargs": {
|
||||
"max_tokens": 6,
|
||||
"standardize": None,
|
||||
"split": None,
|
||||
"output_mode": text_vectorization.INT
|
||||
},
|
||||
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
|
||||
},
|
||||
{
|
||||
"testcase_name":
|
||||
"test_documents_int_mode",
|
||||
@ -985,7 +1005,7 @@ class TextVectorizationOutputTest(
|
||||
output_mode=text_vectorization.BINARY,
|
||||
pad_to_max_tokens=False)
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
|
||||
layer.adapt(vocab_data)
|
||||
|
||||
def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self):
|
||||
@ -1347,6 +1367,7 @@ class TextVectorizationErrorTest(keras_parameterized.TestCase,
|
||||
".*`output_sequence_length` must not be set.*"):
|
||||
_ = get_layer_class()(output_mode="count", output_sequence_length=2)
|
||||
|
||||
|
||||
# Custom functions for the custom callable serialization test. Declared here
|
||||
# to avoid multiple registrations from run_all_keras_modes().
|
||||
@generic_utils.register_keras_serializable(package="Test")
|
||||
@ -1528,208 +1549,5 @@ class TextVectorizationSavingTest(
|
||||
self.assertAllClose(new_output_dataset, expected_output)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
class TextVectorizationCombinerTest(
    keras_parameterized.TestCase,
    preprocessing_test_utils.PreprocessingLayerTest):
  """Tests for the accumulator logic of `_TextVectorizationCombiner`."""

  def compare_text_accumulators(self, a, b, msg=None):
    """Asserts that two text vectorization accumulators are equivalent.

    Args:
      a: The first accumulator, or None.
      b: The second accumulator, or None.
      msg: (Optional) Message forwarded to the underlying assertions.
    """
    if a is None or b is None:
      self.assertAllEqual(a, b, msg=msg)
      # Nothing else can be compared when an accumulator is missing; without
      # this return the attribute accesses below would fail on None.
      return

    self.assertAllEqual(a.count_dict, b.count_dict, msg=msg)
    self.assertAllEqual(a.metadata, b.metadata, msg=msg)

    if a.per_doc_count_dict is not None:

      def per_doc_counts(accumulator):
        # Only the "count" entries are compared; "last_doc_id" is ignored.
        count_values = [
            count_dict["count"]
            for count_dict in accumulator.per_doc_count_dict.values()
        ]
        return dict(zip(accumulator.per_doc_count_dict.keys(), count_values))

      self.assertAllEqual(per_doc_counts(a), per_doc_counts(b), msg=msg)

  compare_accumulators = compare_text_accumulators

  def update_accumulator(self, accumulator, data):
    """Populates `accumulator` in place from a dict of expected values.

    Args:
      accumulator: The accumulator to fill.
      data: A dict with "vocab", "counts" and "num_documents" entries and an
        optional "document_counts" entry.

    Returns:
      The same accumulator, updated.
    """
    accumulator.count_dict.update(dict(zip(data["vocab"], data["counts"])))
    accumulator.metadata[0] = data["num_documents"]

    if "document_counts" in data:
      idf_dict = {
          token: {"count": count, "last_doc_id": -1}
          for token, count in zip(data["vocab"], data["document_counts"])
      }
      accumulator.per_doc_count_dict.update(idf_dict)

    return accumulator

  def test_combiner_api_compatibility_int_mode(self):
    data = np.array([["earth", "wind", "and", "fire"],
                     ["earth", "wind", "and", "michigan"]])
    combiner = text_vectorization._TextVectorizationCombiner(compute_idf=False)
    expected_accumulator_output = {
        "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
        "counts": np.array([2, 2, 2, 1, 1]),
        "num_documents": np.array(2),
    }
    expected_extract_output = {
        "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
    }
    expected_accumulator = combiner._create_accumulator()
    expected_accumulator = self.update_accumulator(expected_accumulator,
                                                   expected_accumulator_output)
    self.validate_accumulator_serialize_and_deserialize(combiner, data,
                                                        expected_accumulator)
    self.validate_accumulator_uniqueness(combiner, data)
    self.validate_accumulator_extract(combiner, data, expected_extract_output)

  def test_combiner_api_compatibility_tfidf_mode(self):
    data = np.array([["earth", "wind", "and", "fire"],
                     ["earth", "wind", "and", "michigan"]])
    combiner = text_vectorization._TextVectorizationCombiner(compute_idf=True)
    expected_extract_output = {
        "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
        "idf": np.array([0.510826, 0.510826, 0.510826, 0.693147, 0.693147]),
        "oov_idf": np.array([1.098612])
    }
    expected_accumulator_output = {
        "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
        "counts": np.array([2, 2, 2, 1, 1]),
        "document_counts": np.array([2, 2, 2, 1, 1]),
        "num_documents": np.array(2),
    }

    expected_accumulator = combiner._create_accumulator()
    expected_accumulator = self.update_accumulator(expected_accumulator,
                                                   expected_accumulator_output)
    self.validate_accumulator_serialize_and_deserialize(combiner, data,
                                                        expected_accumulator)
    self.validate_accumulator_uniqueness(combiner, data)
    self.validate_accumulator_extract(combiner, data, expected_extract_output)

  # TODO(askerryryan): Add tests confirming equivalence to behavior of
  # existing tf.keras.preprocessing.text.Tokenizer.
  @parameterized.named_parameters(
      {
          "testcase_name":
              "top_k_smaller_than_full_vocab",
          "data":
              np.array([["earth", "wind"], ["fire", "wind"], ["and"],
                        ["fire", "wind"]]),
          "vocab_size":
              3,
          "expected_accumulator_output": {
              "vocab": np.array(["wind", "fire", "earth", "and"]),
              "counts": np.array([3, 2, 1, 1]),
              "document_counts": np.array([3, 2, 1, 1]),
              "num_documents": np.array(4),
          },
          "expected_extract_output": {
              "vocab": np.array(["wind", "fire", "earth"]),
              "idf": np.array([0.693147, 0.847298, 1.098612]),
              "oov_idf": np.array([1.609438]),
          },
      },
      {
          "testcase_name":
              "top_k_larger_than_full_vocab",
          "data":
              np.array([["earth", "wind"], ["fire", "wind"], ["and"],
                        ["fire", "wind"]]),
          "vocab_size":
              10,
          "expected_accumulator_output": {
              "vocab": np.array(["wind", "fire", "earth", "and"]),
              "counts": np.array([3, 2, 1, 1]),
              "document_counts": np.array([3, 2, 1, 1]),
              "num_documents": np.array(4),
          },
          "expected_extract_output": {
              "vocab": np.array(["wind", "fire", "earth", "and"]),
              "idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]),
              "oov_idf": np.array([1.609438]),
          },
      },
      {
          "testcase_name":
              "no_top_k",
          "data":
              np.array([["earth", "wind"], ["fire", "wind"], ["and"],
                        ["fire", "wind"]]),
          "vocab_size":
              None,
          "expected_accumulator_output": {
              "vocab": np.array(["wind", "fire", "earth", "and"]),
              "counts": np.array([3, 2, 1, 1]),
              "document_counts": np.array([3, 2, 1, 1]),
              "num_documents": np.array(4),
          },
          "expected_extract_output": {
              "vocab": np.array(["wind", "fire", "earth", "and"]),
              "idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]),
              "oov_idf": np.array([1.609438]),
          },
      },
      {
          "testcase_name": "single_element_per_row",
          "data": np.array([["earth"], ["wind"], ["fire"], ["wind"], ["and"]]),
          "vocab_size": 3,
          "expected_accumulator_output": {
              "vocab": np.array(["wind", "and", "earth", "fire"]),
              "counts": np.array([2, 1, 1, 1]),
              "document_counts": np.array([2, 1, 1, 1]),
              "num_documents": np.array(5),
          },
          "expected_extract_output": {
              "vocab": np.array(["wind", "fire", "earth"]),
              "idf": np.array([0.980829, 1.252763, 1.252763]),
              "oov_idf": np.array([1.791759]),
          },
      },
      # Which tokens are retained are based on global frequency, and thus are
      # sensitive to frequency within a document. In contrast, because idf only
      # considers the presence of a token in a document, it is insensitive
      # to the frequency of the token within the document.
      {
          "testcase_name":
              "retained_tokens_sensitive_to_within_document_frequency",
          "data":
              np.array([["earth", "earth"], ["wind", "wind"], ["fire", "fire"],
                        ["wind", "wind"], ["and", "michigan"]]),
          "vocab_size":
              3,
          "expected_accumulator_output": {
              "vocab": np.array(["wind", "earth", "fire", "and", "michigan"]),
              "counts": np.array([4, 2, 2, 1, 1]),
              "document_counts": np.array([2, 1, 1, 1, 1]),
              "num_documents": np.array(5),
          },
          "expected_extract_output": {
              "vocab": np.array(["wind", "fire", "earth"]),
              "idf": np.array([0.980829, 1.252763, 1.252763]),
              "oov_idf": np.array([1.791759]),
          },
      })
  def test_combiner_computation(self,
                                data,
                                vocab_size,
                                expected_accumulator_output,
                                expected_extract_output,
                                compute_idf=True):
    combiner = text_vectorization._TextVectorizationCombiner(
        vocab_size=vocab_size, compute_idf=compute_idf)
    expected_accumulator = combiner._create_accumulator()
    expected_accumulator = self.update_accumulator(expected_accumulator,
                                                   expected_accumulator_output)
    self.validate_accumulator_computation(combiner, data, expected_accumulator)
    self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run this module's tests when it is executed directly as a script.
  test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user