Error out in lookup layers when invert is true and output_mode is not int
This output is meaningless for IntegerLookup, and will error out for StringLookup. An error on init will be more friendly. PiperOrigin-RevId: 361255701 Change-Id: I4e022938cd8cd696546dd275c83a97d64f944c21
This commit is contained in:
parent
182c5c6f2a
commit
b81c5f0ea5
@ -134,6 +134,10 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
layer_name=self.__class__.__name__,
|
layer_name=self.__class__.__name__,
|
||||||
arg_name="output_mode")
|
arg_name="output_mode")
|
||||||
|
|
||||||
|
if invert and output_mode != INT:
|
||||||
|
raise ValueError("`output_mode` must be {} when `invert` is true. You "
|
||||||
|
"passed {}".format(INT, output_mode))
|
||||||
|
|
||||||
self.invert = invert
|
self.invert = invert
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.num_oov_indices = num_oov_indices
|
self.num_oov_indices = num_oov_indices
|
||||||
|
@ -1431,6 +1431,17 @@ class IndexLookupInverseVocabularyTest(
|
|||||||
dtype=dtypes.string,
|
dtype=dtypes.string,
|
||||||
invert=True)
|
invert=True)
|
||||||
|
|
||||||
|
def test_non_int_output_fails(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, "`output_mode` must be int"):
|
||||||
|
_ = get_layer_class()(
|
||||||
|
max_tokens=None,
|
||||||
|
num_oov_indices=1,
|
||||||
|
mask_token="",
|
||||||
|
oov_token="[OOV]",
|
||||||
|
dtype=dtypes.string,
|
||||||
|
output_mode=index_lookup.COUNT,
|
||||||
|
invert=True)
|
||||||
|
|
||||||
def test_vocab_with_repeated_element_fails(self):
|
def test_vocab_with_repeated_element_fails(self):
|
||||||
vocab_data = ["earth", "earth", "wind", "and", "fire"]
|
vocab_data = ["earth", "earth", "wind", "and", "fire"]
|
||||||
layer = get_layer_class()(
|
layer = get_layer_class()(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user