Make CategoryEncoding work with 1D inputs and list inputs.
Also raise an error when the layer is constructed with `max_tokens=None` and is called without first calling `adapt()`. Note that the 1D fix does not yet apply to SparseTensor and RaggedTensor inputs.
PiperOrigin-RevId: 315803768
Change-Id: I7d302a2a9009ad63db3c5fb6a4209f63da8f2635
This commit is contained in:
parent
5d1368571f
commit
a745f0a953
|
@ -267,12 +267,22 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||
K.set_value(self.tf_idf_weights, tfidf_data)
|
||||
|
||||
def call(self, inputs, count_weights=None):
|
||||
if isinstance(inputs, (list, np.ndarray)):
|
||||
inputs = ops.convert_to_tensor_v2(inputs)
|
||||
if inputs.shape.rank == 1:
|
||||
inputs = array_ops.expand_dims(inputs, 1)
|
||||
|
||||
if count_weights is not None and self._output_mode != COUNT:
|
||||
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
|
||||
"or `output_mode='binary'`. Please pass a single input.")
|
||||
self._called = True
|
||||
if self._max_tokens is None:
|
||||
out_depth = K.get_value(self.num_elements)
|
||||
if out_depth == 0:
|
||||
raise RuntimeError(
|
||||
"If you construct a `CategoryEncoding` layer with "
|
||||
"`max_tokens=None`, you need to call `adapt()` "
|
||||
"on it before using it")
|
||||
else:
|
||||
out_depth = self._max_tokens
|
||||
|
||||
|
@ -352,6 +362,8 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner):
|
|||
|
||||
# TODO(momernick): Benchmark improvements to this algorithm.
|
||||
for element in values:
|
||||
if not isinstance(element, list):
|
||||
element = [element]
|
||||
current_doc_id = accumulator.data[self.DOC_ID_IDX]
|
||||
for value in element:
|
||||
current_max_value = accumulator.data[self.MAX_VALUE_IDX]
|
||||
|
|
|
@ -405,6 +405,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
|
|||
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=None, output_mode=category_encoding.BINARY)
|
||||
layer.adapt([1, 2])
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
|
||||
layer.set_num_elements(5)
|
||||
|
@ -415,6 +416,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
|
|||
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=None, output_mode=category_encoding.BINARY)
|
||||
layer.adapt(vocab_data)
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be adapted"):
|
||||
layer.adapt(vocab_data)
|
||||
|
@ -425,6 +427,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
|
|||
input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=None, output_mode=category_encoding.BINARY)
|
||||
layer.adapt([1, 2])
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
|
||||
layer._set_state_variables(state_variables)
|
||||
|
@ -741,6 +744,21 @@ class CategoryEncodingCombinerTest(
|
|||
self.validate_accumulator_computation(combiner, data, expected_accumulator)
|
||||
self.validate_accumulator_extract(combiner, data, expected_extract_output)
|
||||
|
||||
def test_1d_data(self):
|
||||
data = [1, 2, 3]
|
||||
cls = get_layer_class()
|
||||
layer = cls()
|
||||
layer.adapt(data)
|
||||
output = layer(data)
|
||||
self.assertListEqual(output.shape.as_list(), [3, 4])
|
||||
|
||||
def test_no_adapt_exception(self):
|
||||
cls = get_layer_class()
|
||||
layer = cls()
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r".*you need to call.*"):
|
||||
_ = layer([1, 2, 3])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
@ -1037,6 +1037,7 @@ class TextVectorizationOutputTest(
|
|||
split=None,
|
||||
output_mode=text_vectorization.BINARY,
|
||||
pad_to_max_tokens=False)
|
||||
layer.adapt(vocab_data)
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
||||
layer.set_vocabulary(vocab_data)
|
||||
|
@ -1054,6 +1055,7 @@ class TextVectorizationOutputTest(
|
|||
split=None,
|
||||
output_mode=text_vectorization.BINARY,
|
||||
pad_to_max_tokens=False)
|
||||
layer.adapt(vocab_data)
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"):
|
||||
layer.adapt(vocab_data)
|
||||
|
@ -1070,6 +1072,7 @@ class TextVectorizationOutputTest(
|
|||
split=None,
|
||||
output_mode=text_vectorization.BINARY,
|
||||
pad_to_max_tokens=False)
|
||||
layer.adapt(["earth", "wind"])
|
||||
_ = layer(input_data)
|
||||
with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"):
|
||||
layer._set_state_variables(state_variables)
|
||||
|
|
Loading…
Reference in New Issue