Support weights for CategoryEncoding.
PiperOrigin-RevId: 313682327 Change-Id: I35b38d66ce5a429cff4ca7a178f13c6649b2027b
This commit is contained in:
parent
eb40237008
commit
2003db55b1
@ -71,6 +71,20 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
[0, 1, 1, 0],
|
||||
[0, 1, 0, 1]])>
|
||||
|
||||
|
||||
Examples with weighted inputs:
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
|
||||
... max_tokens=4)
|
||||
>>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])
|
||||
>>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)
|
||||
<tf.Tensor: shape=(4, 4), dtype=float64, numpy=
|
||||
array([[0.1, 0.2, 0. , 0. ],
|
||||
[0.2, 0. , 0. , 0. ],
|
||||
[0. , 0.2, 0.3, 0. ],
|
||||
[0. , 0.2, 0. , 0.4]])>
|
||||
|
||||
|
||||
Attributes:
|
||||
max_tokens: The maximum size of the vocabulary for this layer. If None,
|
||||
there is no cap on the size of the vocabulary.
|
||||
@ -85,6 +99,12 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
value in each token slot.
|
||||
sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
|
||||
`Tensor`. Defaults to `False`.
|
||||
|
||||
Call arguments:
|
||||
inputs: A 2D tensor `(samples, timesteps)`.
|
||||
count_weights: A 2D tensor in the same shape as `inputs` indicating the
|
||||
weight for each sample value when summing up in `count` mode. Not used in
|
||||
`binary` or `tfidf` mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -242,7 +262,10 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
tfidf_data = np.resize(tfidf_data, (self._max_tokens,))
|
||||
K.set_value(self.tf_idf_weights, tfidf_data)
|
||||
|
||||
def call(self, inputs):
|
||||
def call(self, inputs, count_weights=None):
|
||||
if count_weights is not None and self._output_mode != COUNT:
|
||||
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
|
||||
"or `output_mode='binary'`. Please pass a single input.")
|
||||
self._called = True
|
||||
if self._max_tokens is None:
|
||||
out_depth = K.get_value(self.num_elements)
|
||||
@ -264,10 +287,15 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
binary_output = (self._output_mode == BINARY)
|
||||
if self._sparse:
|
||||
return bincount_ops.sparse_bincount(
|
||||
inputs, minlength=out_depth, axis=-1, binary_output=binary_output)
|
||||
inputs,
|
||||
weights=count_weights,
|
||||
minlength=out_depth,
|
||||
axis=-1,
|
||||
binary_output=binary_output)
|
||||
else:
|
||||
result = bincount_ops.bincount(
|
||||
inputs,
|
||||
weights=count_weights,
|
||||
minlength=out_depth,
|
||||
dtype=dtypes.int64,
|
||||
axis=-1,
|
||||
|
@ -109,6 +109,32 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
|
||||
output_dataset = model.predict(sparse_tensor_data, steps=1)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_sparse_input_with_weights(self):
|
||||
input_array = np.array([[1, 2, 3, 4], [4, 3, 1, 4]], dtype=np.int64)
|
||||
weights_array = np.array([[.1, .2, .3, .4], [.2, .1, .4, .3]])
|
||||
sparse_tensor_data = sparse_ops.from_dense(input_array)
|
||||
sparse_weight_data = sparse_ops.from_dense(weights_array)
|
||||
|
||||
# pyformat: disable
|
||||
expected_output = [[0, .1, .2, .3, .4, 0],
|
||||
[0, .4, 0, .1, .5, 0]]
|
||||
# pyformat: enable
|
||||
max_tokens = 6
|
||||
expected_output_shape = [None, max_tokens]
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
|
||||
weight_data = keras.Input(shape=(None,), dtype=dtypes.float32, sparse=True)
|
||||
|
||||
layer = get_layer_class()(
|
||||
max_tokens=max_tokens, output_mode=category_encoding.COUNT)
|
||||
int_data = layer(input_data, count_weights=weight_data)
|
||||
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
|
||||
|
||||
model = keras.Model(inputs=[input_data, weight_data], outputs=int_data)
|
||||
output_dataset = model.predict([sparse_tensor_data, sparse_weight_data],
|
||||
steps=1)
|
||||
self.assertAllClose(expected_output, output_dataset)
|
||||
|
||||
def test_sparse_input_sparse_output(self):
|
||||
sp_inp = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 1], [2, 0], [2, 1], [3, 1]],
|
||||
@ -146,6 +172,33 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
|
||||
sparse_ops.sparse_tensor_to_dense(sp_output_dataset, default_value=0),
|
||||
output_dataset)
|
||||
|
||||
def test_sparse_input_sparse_output_with_weights(self):
|
||||
indices = [[0, 0], [1, 1], [2, 0], [2, 1], [3, 1]]
|
||||
sp_inp = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=[0, 2, 1, 1, 0], dense_shape=[4, 2])
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
|
||||
sp_weight = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=[.1, .2, .4, .3, .2], dense_shape=[4, 2])
|
||||
weight_data = keras.Input(shape=(None,), dtype=dtypes.float32, sparse=True)
|
||||
|
||||
# The expected output should be (X for missing value):
|
||||
# [[1, X, X, X]
|
||||
# [X, X, 1, X]
|
||||
# [X, 2, X, X]
|
||||
# [1, X, X, X]]
|
||||
expected_indices = [[0, 0], [1, 2], [2, 1], [3, 0]]
|
||||
expected_values = [.1, .2, .7, .2]
|
||||
max_tokens = 6
|
||||
|
||||
layer = get_layer_class()(
|
||||
max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
|
||||
int_data = layer(input_data, count_weights=weight_data)
|
||||
|
||||
model = keras.Model(inputs=[input_data, weight_data], outputs=int_data)
|
||||
sp_output_dataset = model.predict([sp_inp, sp_weight], steps=1)
|
||||
self.assertAllClose(expected_values, sp_output_dataset.values)
|
||||
self.assertAllEqual(expected_indices, sp_output_dataset.indices)
|
||||
|
||||
def test_ragged_input(self):
|
||||
input_array = ragged_factory_ops.constant([[1, 2, 3], [3, 1]])
|
||||
|
||||
|
@ -153,7 +153,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'inputs\', \'count_weights\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
@ -151,7 +151,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'inputs\', \'count_weights\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
Loading…
Reference in New Issue
Block a user