Make CategoryEncoding work with 1D inputs and list inputs.

Also add error message when passing max_tokens=None and not calling adapt.

Note that the 1D fix won't apply to SparseTensor and RaggedTensor at this time.

PiperOrigin-RevId: 315803768
Change-Id: I7d302a2a9009ad63db3c5fb6a4209f63da8f2635
This commit is contained in:
Francois Chollet 2020-06-10 17:43:48 -07:00 committed by TensorFlower Gardener
parent 5d1368571f
commit a745f0a953
3 changed files with 33 additions and 0 deletions

View File

@ -267,12 +267,22 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
K.set_value(self.tf_idf_weights, tfidf_data) K.set_value(self.tf_idf_weights, tfidf_data)
def call(self, inputs, count_weights=None): def call(self, inputs, count_weights=None):
if isinstance(inputs, (list, np.ndarray)):
inputs = ops.convert_to_tensor_v2(inputs)
if inputs.shape.rank == 1:
inputs = array_ops.expand_dims(inputs, 1)
if count_weights is not None and self._output_mode != COUNT: if count_weights is not None and self._output_mode != COUNT:
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, " raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
"or `output_mode='binary'`. Please pass a single input.") "or `output_mode='binary'`. Please pass a single input.")
self._called = True self._called = True
if self._max_tokens is None: if self._max_tokens is None:
out_depth = K.get_value(self.num_elements) out_depth = K.get_value(self.num_elements)
if out_depth == 0:
raise RuntimeError(
"If you construct a `CategoryEncoding` layer with "
"`max_tokens=None`, you need to call `adapt()` "
"on it before using it")
else: else:
out_depth = self._max_tokens out_depth = self._max_tokens
@ -352,6 +362,8 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner):
# TODO(momernick): Benchmark improvements to this algorithm. # TODO(momernick): Benchmark improvements to this algorithm.
for element in values: for element in values:
if not isinstance(element, list):
element = [element]
current_doc_id = accumulator.data[self.DOC_ID_IDX] current_doc_id = accumulator.data[self.DOC_ID_IDX]
for value in element: for value in element:
current_max_value = accumulator.data[self.MAX_VALUE_IDX] current_max_value = accumulator.data[self.MAX_VALUE_IDX]

View File

@ -405,6 +405,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=category_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.adapt([1, 2])
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
layer.set_num_elements(5) layer.set_num_elements(5)
@ -415,6 +416,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=category_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.adapt(vocab_data)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "can't be adapted"): with self.assertRaisesRegex(RuntimeError, "can't be adapted"):
layer.adapt(vocab_data) layer.adapt(vocab_data)
@ -425,6 +427,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=category_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.adapt([1, 2])
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
layer._set_state_variables(state_variables) layer._set_state_variables(state_variables)
@ -741,6 +744,21 @@ class CategoryEncodingCombinerTest(
self.validate_accumulator_computation(combiner, data, expected_accumulator) self.validate_accumulator_computation(combiner, data, expected_accumulator)
self.validate_accumulator_extract(combiner, data, expected_extract_output) self.validate_accumulator_extract(combiner, data, expected_extract_output)
def test_1d_data(self):
data = [1, 2, 3]
cls = get_layer_class()
layer = cls()
layer.adapt(data)
output = layer(data)
self.assertListEqual(output.shape.as_list(), [3, 4])
def test_no_adapt_exception(self):
cls = get_layer_class()
layer = cls()
with self.assertRaisesRegex(
RuntimeError, r".*you need to call.*"):
_ = layer([1, 2, 3])
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -1037,6 +1037,7 @@ class TextVectorizationOutputTest(
split=None, split=None,
output_mode=text_vectorization.BINARY, output_mode=text_vectorization.BINARY,
pad_to_max_tokens=False) pad_to_max_tokens=False)
layer.adapt(vocab_data)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
layer.set_vocabulary(vocab_data) layer.set_vocabulary(vocab_data)
@ -1054,6 +1055,7 @@ class TextVectorizationOutputTest(
split=None, split=None,
output_mode=text_vectorization.BINARY, output_mode=text_vectorization.BINARY,
pad_to_max_tokens=False) pad_to_max_tokens=False)
layer.adapt(vocab_data)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"): with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
layer.adapt(vocab_data) layer.adapt(vocab_data)
@ -1070,6 +1072,7 @@ class TextVectorizationOutputTest(
split=None, split=None,
output_mode=text_vectorization.BINARY, output_mode=text_vectorization.BINARY,
pad_to_max_tokens=False) pad_to_max_tokens=False)
layer.adapt(["earth", "wind"])
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
layer._set_state_variables(state_variables) layer._set_state_variables(state_variables)