Support weights for CategoryEncoding.

PiperOrigin-RevId: 313682327
Change-Id: I35b38d66ce5a429cff4ca7a178f13c6649b2027b
This commit is contained in:
Zhenyu Tan 2020-05-28 16:33:23 -07:00 committed by TensorFlower Gardener
parent eb40237008
commit 2003db55b1
4 changed files with 85 additions and 4 deletions

View File

@ -71,6 +71,20 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
[0, 1, 1, 0], [0, 1, 1, 0],
[0, 1, 0, 1]])> [0, 1, 0, 1]])>
Examples with weighted inputs:
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
... max_tokens=4)
>>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])
>>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)
<tf.Tensor: shape=(4, 4), dtype=float64, numpy=
array([[0.1, 0.2, 0. , 0. ],
[0.2, 0. , 0. , 0. ],
[0. , 0.2, 0.3, 0. ],
[0. , 0.2, 0. , 0.4]])>
Attributes: Attributes:
max_tokens: The maximum size of the vocabulary for this layer. If None, max_tokens: The maximum size of the vocabulary for this layer. If None,
there is no cap on the size of the vocabulary. there is no cap on the size of the vocabulary.
@ -85,6 +99,12 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
value in each token slot. value in each token slot.
sparse: Boolean. If true, returns a `SparseTensor` instead of a dense sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
`Tensor`. Defaults to `False`. `Tensor`. Defaults to `False`.
Call arguments:
inputs: A 2D tensor `(samples, timesteps)`.
count_weights: A 2D tensor in the same shape as `inputs` indicating the
weight for each sample value when summing up in `count` mode. Not used in
`binary` or `tfidf` mode.
""" """
def __init__(self, def __init__(self,
@ -242,7 +262,10 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
tfidf_data = np.resize(tfidf_data, (self._max_tokens,)) tfidf_data = np.resize(tfidf_data, (self._max_tokens,))
K.set_value(self.tf_idf_weights, tfidf_data) K.set_value(self.tf_idf_weights, tfidf_data)
def call(self, inputs): def call(self, inputs, count_weights=None):
if count_weights is not None and self._output_mode != COUNT:
raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
"or `output_mode='binary'`. Please pass a single input.")
self._called = True self._called = True
if self._max_tokens is None: if self._max_tokens is None:
out_depth = K.get_value(self.num_elements) out_depth = K.get_value(self.num_elements)
@ -264,10 +287,15 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
binary_output = (self._output_mode == BINARY) binary_output = (self._output_mode == BINARY)
if self._sparse: if self._sparse:
return bincount_ops.sparse_bincount( return bincount_ops.sparse_bincount(
inputs, minlength=out_depth, axis=-1, binary_output=binary_output) inputs,
weights=count_weights,
minlength=out_depth,
axis=-1,
binary_output=binary_output)
else: else:
result = bincount_ops.bincount( result = bincount_ops.bincount(
inputs, inputs,
weights=count_weights,
minlength=out_depth, minlength=out_depth,
dtype=dtypes.int64, dtype=dtypes.int64,
axis=-1, axis=-1,

View File

@ -109,6 +109,32 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
output_dataset = model.predict(sparse_tensor_data, steps=1) output_dataset = model.predict(sparse_tensor_data, steps=1)
self.assertAllEqual(expected_output, output_dataset) self.assertAllEqual(expected_output, output_dataset)
def test_sparse_input_with_weights(self):
input_array = np.array([[1, 2, 3, 4], [4, 3, 1, 4]], dtype=np.int64)
weights_array = np.array([[.1, .2, .3, .4], [.2, .1, .4, .3]])
sparse_tensor_data = sparse_ops.from_dense(input_array)
sparse_weight_data = sparse_ops.from_dense(weights_array)
# pyformat: disable
expected_output = [[0, .1, .2, .3, .4, 0],
[0, .4, 0, .1, .5, 0]]
# pyformat: enable
max_tokens = 6
expected_output_shape = [None, max_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
weight_data = keras.Input(shape=(None,), dtype=dtypes.float32, sparse=True)
layer = get_layer_class()(
max_tokens=max_tokens, output_mode=category_encoding.COUNT)
int_data = layer(input_data, count_weights=weight_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
model = keras.Model(inputs=[input_data, weight_data], outputs=int_data)
output_dataset = model.predict([sparse_tensor_data, sparse_weight_data],
steps=1)
self.assertAllClose(expected_output, output_dataset)
def test_sparse_input_sparse_output(self): def test_sparse_input_sparse_output(self):
sp_inp = sparse_tensor.SparseTensor( sp_inp = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 1], [2, 0], [2, 1], [3, 1]], indices=[[0, 0], [1, 1], [2, 0], [2, 1], [3, 1]],
@ -146,6 +172,33 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
sparse_ops.sparse_tensor_to_dense(sp_output_dataset, default_value=0), sparse_ops.sparse_tensor_to_dense(sp_output_dataset, default_value=0),
output_dataset) output_dataset)
def test_sparse_input_sparse_output_with_weights(self):
indices = [[0, 0], [1, 1], [2, 0], [2, 1], [3, 1]]
sp_inp = sparse_tensor.SparseTensor(
indices=indices, values=[0, 2, 1, 1, 0], dense_shape=[4, 2])
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
sp_weight = sparse_tensor.SparseTensor(
indices=indices, values=[.1, .2, .4, .3, .2], dense_shape=[4, 2])
weight_data = keras.Input(shape=(None,), dtype=dtypes.float32, sparse=True)
# The expected output should be (X for missing value):
# [[1, X, X, X]
# [X, X, 1, X]
# [X, 2, X, X]
# [1, X, X, X]]
expected_indices = [[0, 0], [1, 2], [2, 1], [3, 0]]
expected_values = [.1, .2, .7, .2]
max_tokens = 6
layer = get_layer_class()(
max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
int_data = layer(input_data, count_weights=weight_data)
model = keras.Model(inputs=[input_data, weight_data], outputs=int_data)
sp_output_dataset = model.predict([sp_inp, sp_weight], steps=1)
self.assertAllClose(expected_values, sp_output_dataset.values)
self.assertAllEqual(expected_indices, sp_output_dataset.indices)
def test_ragged_input(self): def test_ragged_input(self):
input_array = ragged_factory_ops.constant([[1, 2, 3], [3, 1]]) input_array = ragged_factory_ops.constant([[1, 2, 3], [3, 1]])

View File

@ -153,7 +153,7 @@ tf_class {
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'inputs\', \'count_weights\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"

View File

@ -151,7 +151,7 @@ tf_class {
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'inputs\', \'count_weights\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"