Error out in lookup layers when invert is true and output_mode is not int
This output is meaningless for IntegerLookup, and will error out for StringLookup. An error on init will be more friendly. PiperOrigin-RevId: 361255701 Change-Id: I4e022938cd8cd696546dd275c83a97d64f944c21
This commit is contained in:
parent
182c5c6f2a
commit
b81c5f0ea5
@ -134,6 +134,10 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
layer_name=self.__class__.__name__,
|
||||
arg_name="output_mode")
|
||||
|
||||
if invert and output_mode != INT:
|
||||
raise ValueError("`output_mode` must be {} when `invert` is true. You "
|
||||
"passed {}".format(INT, output_mode))
|
||||
|
||||
self.invert = invert
|
||||
self.max_tokens = max_tokens
|
||||
self.num_oov_indices = num_oov_indices
|
||||
|
@ -1431,6 +1431,17 @@ class IndexLookupInverseVocabularyTest(
|
||||
dtype=dtypes.string,
|
||||
invert=True)
|
||||
|
||||
def test_non_int_output_fails(self):
|
||||
with self.assertRaisesRegex(ValueError, "`output_mode` must be int"):
|
||||
_ = get_layer_class()(
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="[OOV]",
|
||||
dtype=dtypes.string,
|
||||
output_mode=index_lookup.COUNT,
|
||||
invert=True)
|
||||
|
||||
def test_vocab_with_repeated_element_fails(self):
|
||||
vocab_data = ["earth", "earth", "wind", "and", "fire"]
|
||||
layer = get_layer_class()(
|
||||
|
Loading…
x
Reference in New Issue
Block a user