Make CategoryEncoding work with 1D inputs and list inputs.
Also add an error message for the case where the layer is constructed with `max_tokens=None` and is used without first calling `adapt()`. Note that the 1D fix does not apply to SparseTensor and RaggedTensor inputs at this time.
PiperOrigin-RevId: 315803768
Change-Id: I7d302a2a9009ad63db3c5fb6a4209f63da8f2635
This commit is contained in:
parent
5d1368571f
commit
a745f0a953
|
@ -267,12 +267,22 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||||
K.set_value(self.tf_idf_weights, tfidf_data)
|
K.set_value(self.tf_idf_weights, tfidf_data)
|
||||||
|
|
||||||
def call(self, inputs, count_weights=None):
|
def call(self, inputs, count_weights=None):
|
||||||
|
if isinstance(inputs, (list, np.ndarray)):
|
||||||
|
inputs = ops.convert_to_tensor_v2(inputs)
|
||||||
|
if inputs.shape.rank == 1:
|
||||||
|
inputs = array_ops.expand_dims(inputs, 1)
|
||||||
|
|
||||||
if count_weights is not None and self._output_mode != COUNT:
|
if count_weights is not None and self._output_mode != COUNT:
|
||||||
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
|
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
|
||||||
"or `output_mode='binary'`. Please pass a single input.")
|
"or `output_mode='binary'`. Please pass a single input.")
|
||||||
self._called = True
|
self._called = True
|
||||||
if self._max_tokens is None:
|
if self._max_tokens is None:
|
||||||
out_depth = K.get_value(self.num_elements)
|
out_depth = K.get_value(self.num_elements)
|
||||||
|
if out_depth == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"If you construct a `CategoryEncoding` layer with "
|
||||||
|
"`max_tokens=None`, you need to call `adapt()` "
|
||||||
|
"on it before using it")
|
||||||
else:
|
else:
|
||||||
out_depth = self._max_tokens
|
out_depth = self._max_tokens
|
||||||
|
|
||||||
|
@ -352,6 +362,8 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner):
|
||||||
|
|
||||||
# TODO(momernick): Benchmark improvements to this algorithm.
|
# TODO(momernick): Benchmark improvements to this algorithm.
|
||||||
for element in values:
|
for element in values:
|
||||||
|
if not isinstance(element, list):
|
||||||
|
element = [element]
|
||||||
current_doc_id = accumulator.data[self.DOC_ID_IDX]
|
current_doc_id = accumulator.data[self.DOC_ID_IDX]
|
||||||
for value in element:
|
for value in element:
|
||||||
current_max_value = accumulator.data[self.MAX_VALUE_IDX]
|
current_max_value = accumulator.data[self.MAX_VALUE_IDX]
|
||||||
|
|
|
@ -405,6 +405,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
|
||||||
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
||||||
layer = get_layer_class()(
|
layer = get_layer_class()(
|
||||||
max_tokens=None, output_mode=category_encoding.BINARY)
|
max_tokens=None, output_mode=category_encoding.BINARY)
|
||||||
|
layer.adapt([1, 2])
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
|
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
|
||||||
layer.set_num_elements(5)
|
layer.set_num_elements(5)
|
||||||
|
@ -415,6 +416,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
|
||||||
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
||||||
layer = get_layer_class()(
|
layer = get_layer_class()(
|
||||||
max_tokens=None, output_mode=category_encoding.BINARY)
|
max_tokens=None, output_mode=category_encoding.BINARY)
|
||||||
|
layer.adapt(vocab_data)
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "can't be adapted"):
|
with self.assertRaisesRegex(RuntimeError, "can't be adapted"):
|
||||||
layer.adapt(vocab_data)
|
layer.adapt(vocab_data)
|
||||||
|
@ -425,6 +427,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
|
||||||
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
||||||
layer = get_layer_class()(
|
layer = get_layer_class()(
|
||||||
max_tokens=None, output_mode=category_encoding.BINARY)
|
max_tokens=None, output_mode=category_encoding.BINARY)
|
||||||
|
layer.adapt([1, 2])
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
|
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
|
||||||
layer._set_state_variables(state_variables)
|
layer._set_state_variables(state_variables)
|
||||||
|
@ -741,6 +744,21 @@ class CategoryEncodingCombinerTest(
|
||||||
self.validate_accumulator_computation(combiner, data, expected_accumulator)
|
self.validate_accumulator_computation(combiner, data, expected_accumulator)
|
||||||
self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
||||||
|
|
||||||
|
def test_1d_data(self):
|
||||||
|
data = [1, 2, 3]
|
||||||
|
cls = get_layer_class()
|
||||||
|
layer = cls()
|
||||||
|
layer.adapt(data)
|
||||||
|
output = layer(data)
|
||||||
|
self.assertListEqual(output.shape.as_list(), [3, 4])
|
||||||
|
|
||||||
|
def test_no_adapt_exception(self):
|
||||||
|
cls = get_layer_class()
|
||||||
|
layer = cls()
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, r".*you need to call.*"):
|
||||||
|
_ = layer([1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
@ -1037,6 +1037,7 @@ class TextVectorizationOutputTest(
|
||||||
split=None,
|
split=None,
|
||||||
output_mode=text_vectorization.BINARY,
|
output_mode=text_vectorization.BINARY,
|
||||||
pad_to_max_tokens=False)
|
pad_to_max_tokens=False)
|
||||||
|
layer.adapt(vocab_data)
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
||||||
layer.set_vocabulary(vocab_data)
|
layer.set_vocabulary(vocab_data)
|
||||||
|
@ -1054,6 +1055,7 @@ class TextVectorizationOutputTest(
|
||||||
split=None,
|
split=None,
|
||||||
output_mode=text_vectorization.BINARY,
|
output_mode=text_vectorization.BINARY,
|
||||||
pad_to_max_tokens=False)
|
pad_to_max_tokens=False)
|
||||||
|
layer.adapt(vocab_data)
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
|
with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
|
||||||
layer.adapt(vocab_data)
|
layer.adapt(vocab_data)
|
||||||
|
@ -1070,6 +1072,7 @@ class TextVectorizationOutputTest(
|
||||||
split=None,
|
split=None,
|
||||||
output_mode=text_vectorization.BINARY,
|
output_mode=text_vectorization.BINARY,
|
||||||
pad_to_max_tokens=False)
|
pad_to_max_tokens=False)
|
||||||
|
layer.adapt(["earth", "wind"])
|
||||||
_ = layer(input_data)
|
_ = layer(input_data)
|
||||||
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
||||||
layer._set_state_variables(state_variables)
|
layer._set_state_variables(state_variables)
|
||||||
|
|
Loading…
Reference in New Issue