Fix issues where index_lookup was improperly handling hard vocab caps. Add tests.
PiperOrigin-RevId: 312759072
Change-Id: Id24687eee01a6898473e128b8c2cfeb13be89547
commit 19ed4a9ccf (parent ed39014cf6)
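With a hard vocab cap (max_tokens set), the combiner was previously handed max_tokens directly, so the mask and OOV slots were not subtracted from the number of vocabulary terms it retained. The first hunk below computes an explicit vocab_size = max_tokens - (num_oov_indices + num_mask_tokens) and passes that to the combiner instead.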
@@ -118,9 +118,14 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     else:
       self._oov_value = -1
 
+    if max_tokens is not None:
+      num_mask_tokens = (0 if mask_token is None else 1)
+      vocab_size = max_tokens - (num_oov_indices + num_mask_tokens)
+    else:
+      vocab_size = None
+
     super(IndexLookup, self).__init__(
-        combiner=_IndexLookupCombiner(self.max_tokens, self.mask_token),
-        **kwargs)
+        combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
 
     self._output_dtype = dtypes.int64
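A minimal sketch of the new bookkeeping, assuming the same settings as the hard-vocab-cap test cases added below (max_tokens=5, one OOV index, and an empty-string mask token):

    max_tokens = 5
    num_oov_indices = 1
    mask_token = ""
    num_mask_tokens = 0 if mask_token is None else 1  # "" is not None -> 1
    vocab_size = max_tokens - (num_oov_indices + num_mask_tokens)  # 5 - 2 = 3
    # The combiner keeps only the 3 most frequent terms; the mask and OOV
    # tokens occupy the remaining output indices.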
@@ -86,7 +86,8 @@ def _get_end_to_end_test_cases():
           "vocab_data":
               np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
                         ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
-          "input_data": np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
+          "input_data":
+              np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
           "kwargs": {
               "max_tokens": None,
               "num_oov_indices": 1,
@@ -125,6 +126,78 @@ def _get_end_to_end_test_cases():
           "input_dtype":
               dtypes.int64
       },
+      {
+          "testcase_name":
+              "test_strings_hard_vocab_cap",
+          # Create an array where 'earth' is the most frequent term, followed by
+          # 'wind', then 'and', then 'fire'. This ensures that the vocab
+          # accumulator is sorting by frequency.
+          "vocab_data":
+              np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                        ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+          "input_data":
+              np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
+                        ["and"], ["earth"], ["michigan"]]),
+          "kwargs": {
+              "max_tokens": 5,
+              "num_oov_indices": 1,
+              "mask_token": "",
+              "oov_token": "[OOV]",
+              "dtype": dtypes.string,
+          },
+          "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]],
+          "input_dtype":
+              dtypes.string
+      },
+      {
+          "testcase_name":
+              "test_inverse_strings_hard_vocab_cap",
+          # Create an array where 'earth' is the most frequent term, followed by
+          # 'wind', then 'and', then 'fire'. This ensures that the vocab
+          # accumulator is sorting by frequency.
+          "vocab_data":
+              np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                        ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+          "input_data":
+              np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
+          "kwargs": {
+              "max_tokens": 5,
+              "num_oov_indices": 1,
+              "mask_token": "",
+              "oov_token": "[OOV]",
+              "dtype": dtypes.string,
+              "invert": True
+          },
+          "expected_output":
+              np.array([[b"earth"], [b"wind"], [b"and"], [b"[OOV]"], [b"[OOV]"],
+                        [b"and"], [b"earth"], [b"[OOV]"]]),
+          "input_dtype":
+              dtypes.int64
+      },
+      {
+          "testcase_name":
+              "test_ints_hard_vocab_cap",
+          # Create an array where 1138 is the most frequent term, followed by
+          # 1729, then 725, then 42. This ensures that the vocab accumulator
+          # is sorting by frequency.
+          "vocab_data":
+              np.array([[42], [1138], [1138], [1138], [1138], [1729], [1729],
+                        [1729], [725], [725]],
+                       dtype=np.int64),
+          "input_data":
+              np.array([[1138], [1729], [725], [42], [42], [725], [1138], [4]],
+                       dtype=np.int64),
+          "kwargs": {
+              "max_tokens": 5,
+              "num_oov_indices": 1,
+              "mask_token": 0,
+              "oov_token": -1,
+              "dtype": dtypes.int64,
+          },
+          "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]],
+          "input_dtype":
+              dtypes.int64
+      },
   )
 
   crossed_test_cases = []
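For reference, a minimal usage sketch of what test_strings_hard_vocab_cap exercises. The import paths and the direct use of IndexLookup (the tests go through get_layer_class()) are illustrative assumptions; the constructor arguments and adapt() call mirror the test case above:

    import numpy as np
    from tensorflow.python import keras
    from tensorflow.python.framework import dtypes
    from tensorflow.python.keras.layers.preprocessing import index_lookup

    vocab_data = np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
                           ["wind"], ["wind"], ["wind"], ["and"], ["and"]])
    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
    layer = index_lookup.IndexLookup(
        max_tokens=5, num_oov_indices=1, mask_token="", oov_token="[OOV]",
        dtype=dtypes.string)
    layer.adapt(vocab_data)
    model = keras.Model(input_data, layer(input_data))
    # With a hard cap of 5, only the 3 most frequent terms survive:
    # mask -> 0, [OOV] -> 1, earth -> 2, wind -> 3, and -> 4; "fire" and
    # "michigan" fall back to the OOV index, matching expected_output above.
    print(model.predict(np.array([["earth"], ["wind"], ["and"], ["fire"],
                                  ["michigan"]])))
    # expected: [[2], [3], [4], [1], [1]]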
@@ -1510,5 +1510,39 @@ class TextVectorizationSavingTest(
     self.assertAllEqual(expected_output, new_output_dataset)
 
 
+@keras_parameterized.run_all_keras_modes
+class TextVectorizationE2ETest(keras_parameterized.TestCase,
+                               preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_keras_vocab_trimming_example(self):
+    vocab_data = np.array([
+        "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and",
+        "and", "fire"
+    ])
+    input_array = np.array([["earth", "wind", "and", "earth"],
+                            ["ohio", "and", "earth", "michigan"]])
+
+    # pyformat: disable
+    expected_output = [[1, 2, 1],
+                       [3, 1, 0]]
+    # pyformat: enable
+    max_tokens = 3
+    expected_output_shape = [None, max_tokens]
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(
+        max_tokens=max_tokens,
+        standardize=None,
+        split=None,
+        output_mode=text_vectorization.COUNT,
+        pad_to_max_tokens=True)
+    int_data = layer(input_data)
+    layer.adapt(vocab_data)
+    self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
+    model = keras.Model(input_data, int_data)
+    output = model.predict(input_array)
+    self.assertAllEqual(expected_output, output)
+
+
 if __name__ == "__main__":
   test.main()
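In the end-to-end test above, max_tokens=3 with COUNT output leaves room for only the two most frequent terms after the cap, so the count columns line up as [OOV, "earth", "wind"]: ["earth", "wind", "and", "earth"] counts to [1, 2, 1] because "and" falls into the OOV bucket, and ["ohio", "and", "earth", "michigan"] counts to [3, 1, 0].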