Fix index_lookup's handling of hard vocab caps: when max_tokens is set, reserve slots for the mask and OOV tokens before passing the cap to the combiner. Add tests.

PiperOrigin-RevId: 312759072
Change-Id: Id24687eee01a6898473e128b8c2cfeb13be89547
A. Unique TensorFlower 2020-05-21 16:04:08 -07:00 committed by TensorFlower Gardener
parent ed39014cf6
commit 19ed4a9ccf
3 changed files with 115 additions and 3 deletions

@@ -118,9 +118,14 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     else:
       self._oov_value = -1
 
+    if max_tokens is not None:
+      num_mask_tokens = (0 if mask_token is None else 1)
+      vocab_size = max_tokens - (num_oov_indices + num_mask_tokens)
+    else:
+      vocab_size = None
+
     super(IndexLookup, self).__init__(
-        combiner=_IndexLookupCombiner(self.max_tokens, self.mask_token),
-        **kwargs)
+        combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
     self._output_dtype = dtypes.int64
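
For reference, the arithmetic the hunk above introduces: the combiner previously received max_tokens directly, so the reserved mask and OOV indices silently counted against the number of real terms the vocabulary could hold. A minimal standalone sketch of the corrected computation (the helper name is hypothetical, not part of the layer):

def effective_vocab_size(max_tokens, num_oov_indices=1, mask_token=""):
  """How many real terms fit under a hard vocab cap; None means no cap."""
  if max_tokens is None:
    return None
  # A mask token, if present, occupies one slot of the cap.
  num_mask_tokens = 0 if mask_token is None else 1
  return max_tokens - (num_oov_indices + num_mask_tokens)

# With the defaults exercised by the new tests (one OOV index plus a mask
# token), a cap of 5 leaves room for only the 3 most frequent terms.
assert effective_vocab_size(5) == 3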

@@ -86,7 +86,8 @@ def _get_end_to_end_test_cases():
          "vocab_data":
              np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
                        ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
-         "input_data": np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
+         "input_data":
+             np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
          "kwargs": {
              "max_tokens": None,
              "num_oov_indices": 1,
@@ -125,6 +126,78 @@ def _get_end_to_end_test_cases():
          "input_dtype":
              dtypes.int64
      },
+     {
+         "testcase_name":
+             "test_strings_hard_vocab_cap",
+         # Create an array where 'earth' is the most frequent term, followed by
+         # 'wind', then 'and', then 'fire'. This ensures that the vocab
+         # accumulator is sorting by frequency.
+         "vocab_data":
+             np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                       ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+         "input_data":
+             np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
+                       ["and"], ["earth"], ["michigan"]]),
+         "kwargs": {
+             "max_tokens": 5,
+             "num_oov_indices": 1,
+             "mask_token": "",
+             "oov_token": "[OOV]",
+             "dtype": dtypes.string,
+         },
+         "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]],
+         "input_dtype":
+             dtypes.string
+     },
+     {
+         "testcase_name":
+             "test_inverse_strings_hard_vocab_cap",
+         # Create an array where 'earth' is the most frequent term, followed by
+         # 'wind', then 'and', then 'fire'. This ensures that the vocab
+         # accumulator is sorting by frequency.
+         "vocab_data":
+             np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                       ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+         "input_data":
+             np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
+         "kwargs": {
+             "max_tokens": 5,
+             "num_oov_indices": 1,
+             "mask_token": "",
+             "oov_token": "[OOV]",
+             "dtype": dtypes.string,
+             "invert": True
+         },
+         "expected_output":
+             np.array([[b"earth"], [b"wind"], [b"and"], [b"[OOV]"], [b"[OOV]"],
+                       [b"and"], [b"earth"], [b"[OOV]"]]),
+         "input_dtype":
+             dtypes.int64
+     },
+     {
+         "testcase_name":
+             "test_ints_hard_vocab_cap",
+         # Create an array where 1138 is the most frequent term, followed by
+         # 1729, then 725, then 42. This ensures that the vocab accumulator
+         # is sorting by frequency.
+         "vocab_data":
+             np.array([[42], [1138], [1138], [1138], [1138], [1729], [1729],
+                       [1729], [725], [725]],
+                      dtype=np.int64),
+         "input_data":
+             np.array([[1138], [1729], [725], [42], [42], [725], [1138], [4]],
+                      dtype=np.int64),
+         "kwargs": {
+             "max_tokens": 5,
+             "num_oov_indices": 1,
+             "mask_token": 0,
+             "oov_token": -1,
+             "dtype": dtypes.int64,
+         },
+         "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]],
+         "input_dtype":
+             dtypes.int64
+     },
  )
  crossed_test_cases = []
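
The string hard-cap case above can be checked by hand. With max_tokens=5, one mask slot and one OOV slot leave room for 3 real terms, so the least frequent term "fire" is trimmed from the vocabulary and maps to the OOV index just like the genuinely unseen "michigan". A plain-Python sketch of that bookkeeping (an illustration, not the layer's implementation):

from collections import Counter

tokens = ["fire"] + ["earth"] * 4 + ["wind"] * 3 + ["and"] * 2
cap = 5 - (1 + 1)  # max_tokens minus the OOV and mask slots -> 3 real terms

# Index 0 is the mask token and index 1 the OOV token; real terms follow
# in descending frequency order.
vocab = [term for term, _ in Counter(tokens).most_common(cap)]
table = {term: i + 2 for i, term in enumerate(vocab)}

lookup = lambda term: table.get(term, 1)  # unknown terms -> OOV index 1
inputs = ["earth", "wind", "and", "fire", "fire", "and", "earth", "michigan"]
assert [lookup(t) for t in inputs] == [2, 3, 4, 1, 1, 4, 2, 1]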

@@ -1510,5 +1510,39 @@ class TextVectorizationSavingTest(
     self.assertAllEqual(expected_output, new_output_dataset)
 
 
+@keras_parameterized.run_all_keras_modes
+class TextVectorizationE2ETest(keras_parameterized.TestCase,
+                               preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_keras_vocab_trimming_example(self):
+    vocab_data = np.array([
+        "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and",
+        "and", "fire"
+    ])
+    input_array = np.array([["earth", "wind", "and", "earth"],
+                            ["ohio", "and", "earth", "michigan"]])
+
+    # pyformat: disable
+    expected_output = [[1, 2, 1],
+                       [3, 1, 0]]
+    # pyformat: enable
+    max_tokens = 3
+    expected_output_shape = [None, max_tokens]
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(
+        max_tokens=max_tokens,
+        standardize=None,
+        split=None,
+        output_mode=text_vectorization.COUNT,
+        pad_to_max_tokens=True)
+    int_data = layer(input_data)
+    layer.adapt(vocab_data)
+    self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
+
+    model = keras.Model(input_data, int_data)
+    output = model.predict(input_array)
+    self.assertAllEqual(expected_output, output)
+
+
 if __name__ == "__main__":
   test.main()
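
The E2E test's expected counts follow the same trimming rule, with one wrinkle the expected output implies: in COUNT mode column 0 acts as the OOV bucket rather than a mask slot, so max_tokens=3 leaves room for 2 real terms ("earth" and "wind"), and "and", "fire", "ohio", and "michigan" all land in column 0. A plain-Python sketch of that bookkeeping (an illustration of the expected numbers, not TextVectorization's implementation):

from collections import Counter

adapt_data = ["earth"] * 4 + ["wind"] * 3 + ["and"] * 2 + ["fire"]
max_tokens = 3

# Column 0 is the OOV bucket, so the cap leaves room for max_tokens - 1 = 2
# real terms, kept in descending frequency order.
vocab = [t for t, _ in Counter(adapt_data).most_common(max_tokens - 1)]
column = {term: i + 1 for i, term in enumerate(vocab)}

def count_vector(row):
  out = [0] * max_tokens
  for term in row:
    out[column.get(term, 0)] += 1  # unknown terms count toward column 0
  return out

batch = [["earth", "wind", "and", "earth"],
         ["ohio", "and", "earth", "michigan"]]
assert [count_vector(r) for r in batch] == [[1, 2, 1], [3, 1, 0]]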