From 39d3d60407da10ec3e25e1ae832b3325ac3feb6a Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 17 Nov 2020 16:38:29 -0800 Subject: [PATCH] Make sure CategoryEncoding raises an error with input values equal or greater to `max_tokens`. PiperOrigin-RevId: 342972294 Change-Id: If8f186776d784f22487ea179c8e2f6d8d962bbf3 --- .../python/keras/layers/preprocessing/category_encoding.py | 2 +- .../python/keras/layers/preprocessing/category_encoding_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py index 7b57d1b33c3..53ab593bfda 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py @@ -304,7 +304,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): max_value = math_ops.reduce_max(inputs) min_value = math_ops.reduce_min(inputs) condition = math_ops.logical_and( - math_ops.greater_equal( + math_ops.greater( math_ops.cast(out_depth, max_value.dtype), max_value), math_ops.greater_equal( min_value, math_ops.cast(0, min_value.dtype))) diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py index 3b2026e5048..e5ba0bd1e17 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py @@ -269,7 +269,7 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase, _ = model.predict(input_array, steps=1) def test_dense_oov_input(self): - input_array = constant_op.constant([[1, 2, 3], [4, 3, 4]]) + input_array = constant_op.constant([[0, 1, 2], [2, 3, 1]]) max_tokens = 3 expected_output_shape = [None, max_tokens] encoder_layer = get_layer_class()(max_tokens)