Update TextVectorization to use internal layer adapt calls instead of its own combiner.

PiperOrigin-RevId: 312292748
Change-Id: Ia157a06f55a28325dac9e4a58b3fed23fc4599d4
This commit is contained in:
A. Unique TensorFlower 2020-05-19 09:16:48 -07:00 committed by TensorFlower Gardener
parent 282db86128
commit 15d39f5e83
3 changed files with 46 additions and 418 deletions

View File

@ -55,8 +55,9 @@ class CombinerPreprocessingLayer(
def _get_dataset_iterator(self, dataset): def _get_dataset_iterator(self, dataset):
"""Gets an iterator from a tf.data.Dataset.""" """Gets an iterator from a tf.data.Dataset."""
iterator = dataset_ops.make_one_shot_iterator(dataset) iterator = dataset_ops.make_initializable_iterator(dataset)
session = K.get_session() session = K.get_session()
session.run(iterator.initializer)
next_element = iterator.get_next() next_element = iterator.get_next()
return lambda: session.run(next_element) return lambda: session.run(next_element)

View File

@ -17,10 +17,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import json
import operator
import numpy as np import numpy as np
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
@ -29,7 +25,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import categorical_encoding from tensorflow.python.keras.layers.preprocessing import categorical_encoding
from tensorflow.python.keras.layers.preprocessing import string_lookup from tensorflow.python.keras.layers.preprocessing import string_lookup
@ -41,7 +36,6 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_string_ops from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation" LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
@ -122,7 +116,9 @@ class TextVectorization(CombinerPreprocessingLayer):
Attributes: Attributes:
max_tokens: The maximum size of the vocabulary for this layer. If None, max_tokens: The maximum size of the vocabulary for this layer. If None,
there is no cap on the size of the vocabulary. there is no cap on the size of the vocabulary. Note that this vocabulary
contains 1 OOV token, so the effective number of tokens is `(max_tokens -
1 - (1 if output == "int" else 0))`
standardize: Optional specification for standardization to apply to the standardize: Optional specification for standardization to apply to the
input text. Values can be None (no standardization), input text. Values can be None (no standardization),
'lower_and_strip_punctuation' (lowercase and remove punctuation) or a 'lower_and_strip_punctuation' (lowercase and remove punctuation) or a
@ -138,7 +134,8 @@ class TextVectorization(CombinerPreprocessingLayer):
output_mode: Optional specification for the output of the layer. Values can output_mode: Optional specification for the output of the layer. Values can
be "int", "binary", "count" or "tf-idf", configuring the layer as follows: be "int", "binary", "count" or "tf-idf", configuring the layer as follows:
"int": Outputs integer indices, one integer index per split string "int": Outputs integer indices, one integer index per split string
token. token. When output == "int", 0 is reserved for masked locations;
this reduces the vocab size to max_tokens-2 instead of max_tokens-1
"binary": Outputs a single int array per batch, of either vocab_size or "binary": Outputs a single int array per batch, of either vocab_size or
max_tokens size, containing 1s in all elements where the token mapped max_tokens size, containing 1s in all elements where the token mapped
to that index exists at least once in the batch item. to that index exists at least once in the batch item.
@ -274,12 +271,6 @@ class TextVectorization(CombinerPreprocessingLayer):
# the OOV value to zero instead of one. # the OOV value to zero instead of one.
self._oov_value = 1 if output_mode == INT else 0 self._oov_value = 1 if output_mode == INT else 0
# We always reduce the max token number by 1 to account for the OOV token
# if it is set. Keras' use of the reserved number 0 for padding tokens,
# if the output is in INT mode, does not really count as a 'token' for
# vocabulary purposes, so we only reduce vocab size by 1 here.
self._max_vocab_size = max_tokens - 1 if max_tokens is not None else None
self._standardize = standardize self._standardize = standardize
self._split = split self._split = split
self._ngrams_arg = ngrams self._ngrams_arg = ngrams
@ -295,8 +286,7 @@ class TextVectorization(CombinerPreprocessingLayer):
self._called = False self._called = False
super(TextVectorization, self).__init__( super(TextVectorization, self).__init__(
combiner=_TextVectorizationCombiner( combiner=None,
self._max_vocab_size, compute_idf=output_mode == TFIDF),
**kwargs) **kwargs)
mask_token = "" if output_mode in [None, INT] else None mask_token = "" if output_mode in [None, INT] else None
@ -306,14 +296,14 @@ class TextVectorization(CombinerPreprocessingLayer):
# If this layer is configured for string or integer output, we do not # If this layer is configured for string or integer output, we do not
# create a vectorization layer (as the output is not vectorized). # create a vectorization layer (as the output is not vectorized).
if self._output_mode in [None, INT]: if self._output_mode in [None, INT]:
return self._vectorize_layer = None
if max_tokens is not None and self._pad_to_max:
max_elements = max_tokens
else: else:
max_elements = None if max_tokens is not None and self._pad_to_max:
self._vectorize_layer = self._get_vectorization_class()( max_elements = max_tokens
max_tokens=max_elements, output_mode=self._output_mode) else:
max_elements = None
self._vectorize_layer = self._get_vectorization_class()(
max_tokens=max_elements, output_mode=self._output_mode)
# These are V1/V2 shim points. There are V1 implementations in the V1 class. # These are V1/V2 shim points. There are V1 implementations in the V1 class.
def _get_vectorization_class(self): def _get_vectorization_class(self):
@ -407,7 +397,14 @@ class TextVectorization(CombinerPreprocessingLayer):
raise ValueError( raise ValueError(
"adapt() requires a Dataset or an array as input, got {}".format( "adapt() requires a Dataset or an array as input, got {}".format(
type(data))) type(data)))
super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
self._index_lookup_layer.adapt(preprocessed_inputs)
if self._vectorize_layer:
if isinstance(data, ops.Tensor):
integer_data = self._index_lookup_layer(preprocessed_inputs)
else:
integer_data = preprocessed_inputs.map(self._index_lookup_layer)
self._vectorize_layer.adapt(integer_data)
def get_vocabulary(self): def get_vocabulary(self):
return self._index_lookup_layer.get_vocabulary() return self._index_lookup_layer.get_vocabulary()
@ -616,191 +613,3 @@ class TextVectorization(CombinerPreprocessingLayer):
# If we're not returning integers here, we rely on the vectorization layer # If we're not returning integers here, we rely on the vectorization layer
# to create the output. # to create the output.
return self._vectorize_layer(indexed_data) return self._vectorize_layer(indexed_data)
class _TextVectorizationAccumulator(
collections.namedtuple("_TextVectorizationAccumulator",
["count_dict", "per_doc_count_dict", "metadata"])):
pass
# A note on this combiner: This contains functionality that will be extracted
# into the Vectorization and IndexLookup combiner objects. At that point,
# TextVectorization can become a PreprocessingStage instead of a Layer and
# this combiner can be retired. Until then, we leave this as is instead of
# attempting a refactor of what will soon be deleted.
class _TextVectorizationCombiner(Combiner):
"""Combiner for the TextVectorization preprocessing layer.
This class encapsulates the logic for computing a vocabulary based on the
frequency of each token.
Attributes:
vocab_size: (Optional) If set, only the top `vocab_size` tokens (based on
frequency across the dataset) are retained in the vocabulary. If None, or
set to a value greater than the total number of distinct tokens in the
dataset, all tokens are retained.
compute_idf: (Optional) If set, the inverse document frequency will be
computed for each value.
"""
def __init__(self, vocab_size=None, compute_idf=False):
self._vocab_size = vocab_size
self._compute_idf = compute_idf
self._input_dtype = dtypes.string
def compute(self, values, accumulator=None):
"""Compute a step in this computation, returning a new accumulator."""
if dtypes.as_dtype(self._input_dtype) != dtypes.as_dtype(values.dtype):
raise RuntimeError("Expected input type %s, got %s" %
(self._input_dtype, values.dtype))
if ragged_tensor.is_ragged(values):
values = values.to_list()
if isinstance(values, ops.EagerTensor):
values = values.numpy()
if isinstance(values, np.ndarray):
values = values.tolist()
if accumulator is None:
accumulator = self._create_accumulator()
# If we are being passed raw strings or bytestrings, we need to wrap them
# in an array so we don't accidentally iterate over the bytes instead of
# treating the string as one object.
if isinstance(values, (str, bytes)):
values = [values]
# TODO(momernick): Benchmark improvements to this algorithm.
for document in values:
current_doc_id = accumulator.metadata[0]
for token in document:
accumulator.count_dict[token] += 1
if self._compute_idf:
doc_count = accumulator.per_doc_count_dict[token]
if doc_count["last_doc_id"] != current_doc_id:
doc_count["count"] += 1
doc_count["last_doc_id"] = current_doc_id
accumulator.metadata[0] += 1
return accumulator
def merge(self, accumulators):
"""Merge several accumulators to a single accumulator."""
if not accumulators:
return accumulators
base_accumulator = accumulators[0]
for accumulator in accumulators[1:]:
base_accumulator.metadata[0] += accumulator.metadata[0]
for token, value in accumulator.count_dict.items():
base_accumulator.count_dict[token] += value
if self._compute_idf:
for token, value in accumulator.per_doc_count_dict.items():
# Any newly created token counts in 'base_accumulator''s
# per_doc_count_dict will have a last_doc_id of -1. This is always
# less than the next doc id (which are strictly positive), so any
# future occurrences are guaranteed to be counted.
base_accumulator.per_doc_count_dict[token]["count"] += value["count"]
return base_accumulator
def _inverse_document_frequency(self, document_counts, num_documents):
"""Compute the inverse-document-frequency (IDF) component of TFIDF.
Uses the default weighting scheme described in
https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
Args:
document_counts: An array of the # of documents each token appears in.
num_documents: An int representing the total number of documents
Returns:
An array of "inverse document frequency" weights.
"""
return np.log(1 + num_documents / (1 + np.array(document_counts)))
def extract(self, accumulator):
"""Convert an accumulator into a dict of output values.
Args:
accumulator: An accumulator aggregating over the full dataset.
Returns:
A dict of:
"vocab": A list of the retained items in the vocabulary.
"idf": The inverse-document-frequency for each item in vocab.
idf[vocab_idx] is the IDF value for the corresponding vocab item.
"oov_idf": The inverse-document-frequency for the OOV token.
"""
if self._compute_idf:
vocab_counts, document_counts, num_documents = accumulator
else:
vocab_counts, _, _ = accumulator
sorted_counts = sorted(
vocab_counts.items(), key=operator.itemgetter(1, 0), reverse=True)
vocab_data = (
sorted_counts[:self._vocab_size] if self._vocab_size else sorted_counts)
vocab = [data[0] for data in vocab_data]
if self._compute_idf:
doc_counts = [document_counts[token]["count"] for token in vocab]
idf = self._inverse_document_frequency(doc_counts, num_documents[0])
oov_idf = np.array([np.log(1 + num_documents[0])])
return {_VOCAB_NAME: vocab, _IDF_NAME: idf, _OOV_IDF_NAME: oov_idf}
else:
return {_VOCAB_NAME: vocab}
def restore(self, output):
"""Create an accumulator based on 'output'."""
raise NotImplementedError(
"TextVectorization does not restore or support streaming updates.")
def serialize(self, accumulator):
"""Serialize an accumulator for a remote call."""
output_dict = {}
output_dict["metadata"] = accumulator.metadata
output_dict["vocab"] = list(accumulator.count_dict.keys())
output_dict["vocab_counts"] = list(accumulator.count_dict.values())
if self._compute_idf:
output_dict["idf_vocab"] = list(accumulator.per_doc_count_dict.keys())
output_dict["idf_counts"] = [
counter["count"]
for counter in accumulator.per_doc_count_dict.values()
]
return compat.as_bytes(json.dumps(output_dict))
def deserialize(self, encoded_accumulator):
"""Deserialize an accumulator received from 'serialize()'."""
accumulator_dict = json.loads(compat.as_text(encoded_accumulator))
accumulator = self._create_accumulator()
accumulator.metadata[0] = accumulator_dict["metadata"][0]
count_dict = dict(
zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"]))
accumulator.count_dict.update(count_dict)
if self._compute_idf:
create_dict = lambda x: {"count": x, "last_doc_id": -1}
idf_count_dicts = [
create_dict(count) for count in accumulator_dict["idf_counts"]
]
idf_dict = dict(zip(accumulator_dict["idf_vocab"], idf_count_dicts))
accumulator.per_doc_count_dict.update(idf_dict)
return accumulator
def _create_accumulator(self):
"""Accumulate a sorted array of vocab tokens and corresponding counts."""
count_dict = collections.defaultdict(int)
if self._compute_idf:
create_default_dict = lambda: {"count": 0, "last_doc_id": -1}
per_doc_count_dict = collections.defaultdict(create_default_dict)
else:
per_doc_count_dict = None
metadata = [0]
return _TextVectorizationAccumulator(count_dict, per_doc_count_dict,
metadata)

View File

@ -62,7 +62,7 @@ def _get_end_to_end_test_cases():
"testcase_name": "testcase_name":
"test_simple_tokens_int_mode", "test_simple_tokens_int_mode",
# Create an array where 'earth' is the most frequent term, followed by # Create an array where 'earth' is the most frequent term, followed by
# 'wind', then 'and', then 'fire'. This ensures that the vocab accumulator # 'wind', then 'and', then 'fire'. This ensures that the vocab
# is sorting by frequency. # is sorting by frequency.
"vocab_data": "vocab_data":
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
@ -78,6 +78,26 @@ def _get_end_to_end_test_cases():
}, },
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]], "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
}, },
{
"testcase_name":
"test_simple_tokens_int_mode_hard_cap",
# Create an array where 'earth' is the most frequent term, followed by
# 'wind', then 'and', then 'fire'. This ensures that the vocab
# is sorting by frequency.
"vocab_data":
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
"input_data":
np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
["and"], ["earth"], ["michigan"]]),
"kwargs": {
"max_tokens": 6,
"standardize": None,
"split": None,
"output_mode": text_vectorization.INT
},
"expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
},
{ {
"testcase_name": "testcase_name":
"test_documents_int_mode", "test_documents_int_mode",
@ -985,7 +1005,7 @@ class TextVectorizationOutputTest(
output_mode=text_vectorization.BINARY, output_mode=text_vectorization.BINARY,
pad_to_max_tokens=False) pad_to_max_tokens=False)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
layer.adapt(vocab_data) layer.adapt(vocab_data)
def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self): def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self):
@ -1347,6 +1367,7 @@ class TextVectorizationErrorTest(keras_parameterized.TestCase,
".*`output_sequence_length` must not be set.*"): ".*`output_sequence_length` must not be set.*"):
_ = get_layer_class()(output_mode="count", output_sequence_length=2) _ = get_layer_class()(output_mode="count", output_sequence_length=2)
# Custom functions for the custom callable serialization test. Declared here # Custom functions for the custom callable serialization test. Declared here
# to avoid multiple registrations from run_all_keras_modes(). # to avoid multiple registrations from run_all_keras_modes().
@generic_utils.register_keras_serializable(package="Test") @generic_utils.register_keras_serializable(package="Test")
@ -1528,208 +1549,5 @@ class TextVectorizationSavingTest(
self.assertAllClose(new_output_dataset, expected_output) self.assertAllClose(new_output_dataset, expected_output)
@keras_parameterized.run_all_keras_modes
class TextVectorizationCombinerTest(
keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
def compare_text_accumulators(self, a, b, msg=None):
if a is None or b is None:
self.assertAllEqual(a, b, msg=msg)
self.assertAllEqual(a.count_dict, b.count_dict, msg=msg)
self.assertAllEqual(a.metadata, b.metadata, msg=msg)
if a.per_doc_count_dict is not None:
def per_doc_counts(accumulator):
count_values = [
count_dict["count"]
for count_dict in accumulator.per_doc_count_dict.values()
]
return dict(zip(accumulator.per_doc_count_dict.keys(), count_values))
self.assertAllEqual(per_doc_counts(a), per_doc_counts(b), msg=msg)
compare_accumulators = compare_text_accumulators
def update_accumulator(self, accumulator, data):
accumulator.count_dict.update(dict(zip(data["vocab"], data["counts"])))
accumulator.metadata[0] = data["num_documents"]
if "document_counts" in data:
create_dict = lambda x: {"count": x, "last_doc_id": -1}
idf_count_dicts = [
create_dict(count) for count in data["document_counts"]
]
idf_dict = dict(zip(data["vocab"], idf_count_dicts))
accumulator.per_doc_count_dict.update(idf_dict)
return accumulator
def test_combiner_api_compatibility_int_mode(self):
data = np.array([["earth", "wind", "and", "fire"],
["earth", "wind", "and", "michigan"]])
combiner = text_vectorization._TextVectorizationCombiner(compute_idf=False)
expected_accumulator_output = {
"vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
"counts": np.array([2, 2, 2, 1, 1]),
"num_documents": np.array(2),
}
expected_extract_output = {
"vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
}
expected_accumulator = combiner._create_accumulator()
expected_accumulator = self.update_accumulator(expected_accumulator,
expected_accumulator_output)
self.validate_accumulator_serialize_and_deserialize(combiner, data,
expected_accumulator)
self.validate_accumulator_uniqueness(combiner, data)
self.validate_accumulator_extract(combiner, data, expected_extract_output)
def test_combiner_api_compatibility_tfidf_mode(self):
data = np.array([["earth", "wind", "and", "fire"],
["earth", "wind", "and", "michigan"]])
combiner = text_vectorization._TextVectorizationCombiner(compute_idf=True)
expected_extract_output = {
"vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
"idf": np.array([0.510826, 0.510826, 0.510826, 0.693147, 0.693147]),
"oov_idf": np.array([1.098612])
}
expected_accumulator_output = {
"vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
"counts": np.array([2, 2, 2, 1, 1]),
"document_counts": np.array([2, 2, 2, 1, 1]),
"num_documents": np.array(2),
}
expected_accumulator = combiner._create_accumulator()
expected_accumulator = self.update_accumulator(expected_accumulator,
expected_accumulator_output)
self.validate_accumulator_serialize_and_deserialize(combiner, data,
expected_accumulator)
self.validate_accumulator_uniqueness(combiner, data)
self.validate_accumulator_extract(combiner, data, expected_extract_output)
# TODO(askerryryan): Add tests confirming equivalence to behavior of
# existing tf.keras.preprocessing.text.Tokenizer.
@parameterized.named_parameters(
{
"testcase_name":
"top_k_smaller_than_full_vocab",
"data":
np.array([["earth", "wind"], ["fire", "wind"], ["and"],
["fire", "wind"]]),
"vocab_size":
3,
"expected_accumulator_output": {
"vocab": np.array(["wind", "fire", "earth", "and"]),
"counts": np.array([3, 2, 1, 1]),
"document_counts": np.array([3, 2, 1, 1]),
"num_documents": np.array(4),
},
"expected_extract_output": {
"vocab": np.array(["wind", "fire", "earth"]),
"idf": np.array([0.693147, 0.847298, 1.098612]),
"oov_idf": np.array([1.609438]),
},
},
{
"testcase_name":
"top_k_larger_than_full_vocab",
"data":
np.array([["earth", "wind"], ["fire", "wind"], ["and"],
["fire", "wind"]]),
"vocab_size":
10,
"expected_accumulator_output": {
"vocab": np.array(["wind", "fire", "earth", "and"]),
"counts": np.array([3, 2, 1, 1]),
"document_counts": np.array([3, 2, 1, 1]),
"num_documents": np.array(4),
},
"expected_extract_output": {
"vocab": np.array(["wind", "fire", "earth", "and"]),
"idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]),
"oov_idf": np.array([1.609438]),
},
},
{
"testcase_name":
"no_top_k",
"data":
np.array([["earth", "wind"], ["fire", "wind"], ["and"],
["fire", "wind"]]),
"vocab_size":
None,
"expected_accumulator_output": {
"vocab": np.array(["wind", "fire", "earth", "and"]),
"counts": np.array([3, 2, 1, 1]),
"document_counts": np.array([3, 2, 1, 1]),
"num_documents": np.array(4),
},
"expected_extract_output": {
"vocab": np.array(["wind", "fire", "earth", "and"]),
"idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]),
"oov_idf": np.array([1.609438]),
},
},
{
"testcase_name": "single_element_per_row",
"data": np.array([["earth"], ["wind"], ["fire"], ["wind"], ["and"]]),
"vocab_size": 3,
"expected_accumulator_output": {
"vocab": np.array(["wind", "and", "earth", "fire"]),
"counts": np.array([2, 1, 1, 1]),
"document_counts": np.array([2, 1, 1, 1]),
"num_documents": np.array(5),
},
"expected_extract_output": {
"vocab": np.array(["wind", "fire", "earth"]),
"idf": np.array([0.980829, 1.252763, 1.252763]),
"oov_idf": np.array([1.791759]),
},
},
# Which tokens are retained are based on global frequency, and thus are
# sensitive to frequency within a document. In contrast, because idf only
# considers the presence of a token in a document, it is insensitive
# to the frequency of the token within the document.
{
"testcase_name":
"retained_tokens_sensitive_to_within_document_frequency",
"data":
np.array([["earth", "earth"], ["wind", "wind"], ["fire", "fire"],
["wind", "wind"], ["and", "michigan"]]),
"vocab_size":
3,
"expected_accumulator_output": {
"vocab": np.array(["wind", "earth", "fire", "and", "michigan"]),
"counts": np.array([4, 2, 2, 1, 1]),
"document_counts": np.array([2, 1, 1, 1, 1]),
"num_documents": np.array(5),
},
"expected_extract_output": {
"vocab": np.array(["wind", "fire", "earth"]),
"idf": np.array([0.980829, 1.252763, 1.252763]),
"oov_idf": np.array([1.791759]),
},
})
def test_combiner_computation(self,
data,
vocab_size,
expected_accumulator_output,
expected_extract_output,
compute_idf=True):
combiner = text_vectorization._TextVectorizationCombiner(
vocab_size=vocab_size, compute_idf=compute_idf)
expected_accumulator = combiner._create_accumulator()
expected_accumulator = self.update_accumulator(expected_accumulator,
expected_accumulator_output)
self.validate_accumulator_computation(combiner, data, expected_accumulator)
self.validate_accumulator_extract(combiner, data, expected_extract_output)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()