diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py index 26c8d437c08..128188b09c2 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py @@ -267,12 +267,22 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): K.set_value(self.tf_idf_weights, tfidf_data) def call(self, inputs, count_weights=None): + if isinstance(inputs, (list, np.ndarray)): + inputs = ops.convert_to_tensor_v2(inputs) + if inputs.shape.rank == 1: + inputs = array_ops.expand_dims(inputs, 1) + if count_weights is not None and self._output_mode != COUNT: raise ValueError("count_weights is not used in `output_mode='tf-idf'`, " "or `output_mode='binary'`. Please pass a single input.") self._called = True if self._max_tokens is None: out_depth = K.get_value(self.num_elements) + if out_depth == 0: + raise RuntimeError( + "If you construct a `CategoryEncoding` layer with " + "`max_tokens=None`, you need to call `adapt()` " + "on it before using it") else: out_depth = self._max_tokens @@ -352,6 +362,8 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): # TODO(momernick): Benchmark improvements to this algorithm. for element in values: + if not isinstance(element, list): + element = [element] current_doc_id = accumulator.data[self.DOC_ID_IDX] for value in element: current_max_value = accumulator.data[self.MAX_VALUE_IDX] diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py index 048ac3734af..ff1a06a3ae7 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py @@ -405,6 +405,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase, input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( max_tokens=None, output_mode=category_encoding.BINARY) + layer.adapt([1, 2]) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): layer.set_num_elements(5) @@ -415,6 +416,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase, input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( max_tokens=None, output_mode=category_encoding.BINARY) + layer.adapt(vocab_data) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "can't be adapted"): layer.adapt(vocab_data) @@ -425,6 +427,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase, input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( max_tokens=None, output_mode=category_encoding.BINARY) + layer.adapt([1, 2]) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): layer._set_state_variables(state_variables) @@ -741,6 +744,21 @@ class CategoryEncodingCombinerTest( self.validate_accumulator_computation(combiner, data, expected_accumulator) self.validate_accumulator_extract(combiner, data, expected_extract_output) + def test_1d_data(self): + data = [1, 2, 3] + cls = get_layer_class() + layer = cls() + layer.adapt(data) + output = layer(data) + self.assertListEqual(output.shape.as_list(), [3, 4]) + + def test_no_adapt_exception(self): + cls = get_layer_class() + layer = cls() + with self.assertRaisesRegex( + RuntimeError, r".*you need to call.*"): + _ = layer([1, 2, 3]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py index 88df3013257..992f47efc85 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py @@ -1037,6 +1037,7 @@ class TextVectorizationOutputTest( split=None, output_mode=text_vectorization.BINARY, pad_to_max_tokens=False) + layer.adapt(vocab_data) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): layer.set_vocabulary(vocab_data) @@ -1054,6 +1055,7 @@ class TextVectorizationOutputTest( split=None, output_mode=text_vectorization.BINARY, pad_to_max_tokens=False) + layer.adapt(vocab_data) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"): layer.adapt(vocab_data) @@ -1070,6 +1072,7 @@ class TextVectorizationOutputTest( split=None, output_mode=text_vectorization.BINARY, pad_to_max_tokens=False) + layer.adapt(["earth", "wind"]) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): layer._set_state_variables(state_variables)