Make sure CategoryEncoding raises an error with input values equal or greater to max_tokens.

PiperOrigin-RevId: 342972294
Change-Id: If8f186776d784f22487ea179c8e2f6d8d962bbf3
This commit is contained in:
Francois Chollet 2020-11-17 16:38:29 -08:00 committed by TensorFlower Gardener
parent 086faf6609
commit 39d3d60407
2 changed files with 2 additions and 2 deletions

View File

@ -304,7 +304,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
max_value = math_ops.reduce_max(inputs)
min_value = math_ops.reduce_min(inputs)
condition = math_ops.logical_and(
math_ops.greater_equal(
math_ops.greater(
math_ops.cast(out_depth, max_value.dtype), max_value),
math_ops.greater_equal(
min_value, math_ops.cast(0, min_value.dtype)))

View File

@ -269,7 +269,7 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
_ = model.predict(input_array, steps=1)
def test_dense_oov_input(self):
input_array = constant_op.constant([[1, 2, 3], [4, 3, 4]])
input_array = constant_op.constant([[0, 1, 2], [2, 3, 1]])
max_tokens = 3
expected_output_shape = [None, max_tokens]
encoder_layer = get_layer_class()(max_tokens)