Error out in lookup layers when invert is true and output_mode is not int

This output is meaningless for IntegerLookup, and will error out for StringLookup.
An error on init will be more friendly.

PiperOrigin-RevId: 361255701
Change-Id: I4e022938cd8cd696546dd275c83a97d64f944c21
This commit is contained in:
Matt Watson 2021-03-05 17:14:54 -08:00 committed by TensorFlower Gardener
parent 182c5c6f2a
commit b81c5f0ea5
2 changed files with 15 additions and 0 deletions

View File

@ -134,6 +134,10 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
layer_name=self.__class__.__name__,
arg_name="output_mode")
if invert and output_mode != INT:
raise ValueError("`output_mode` must be {} when `invert` is true. You "
"passed {}".format(INT, output_mode))
self.invert = invert
self.max_tokens = max_tokens
self.num_oov_indices = num_oov_indices

View File

@ -1431,6 +1431,17 @@ class IndexLookupInverseVocabularyTest(
dtype=dtypes.string,
invert=True)
def test_non_int_output_fails(self):
with self.assertRaisesRegex(ValueError, "`output_mode` must be int"):
_ = get_layer_class()(
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
dtype=dtypes.string,
output_mode=index_lookup.COUNT,
invert=True)
def test_vocab_with_repeated_element_fails(self):
vocab_data = ["earth", "earth", "wind", "and", "fire"]
layer = get_layer_class()(