Support a mask_value argument for preprocessing.hashing

This argument will allow specifying a value which should always map to the zero
index. For now, this will only be supported for a single tensor input as the
desired behavior when crossing multiple inputs is unclear.

PiperOrigin-RevId: 351904657
Change-Id: I8ae3fd88ef94f7b1244cd1a6da7adbc2a40dfef1
This commit is contained in:
Matt Watson 2021-01-14 16:48:41 -08:00 committed by TensorFlower Gardener
parent 73a6839fb3
commit 5fb1d0e838
5 changed files with 111 additions and 32 deletions

View File

@ -97,8 +97,10 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
@ -106,8 +108,6 @@ py_library(
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python/keras/engine",
"//tensorflow/python/keras/utils:tf_utils",
"//tensorflow/python/ops/ragged:ragged_functional_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/util:tf_export",
"//third_party/py/numpy",

View File

@ -28,11 +28,11 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export
@ -71,6 +71,19 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
[1],
[2]])>
Example (FarmHash64) with a mask value:
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
... mask_value='')
>>> inp = [['A'], ['B'], [''], ['C'], ['D']]
>>> layer(inp)
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[1],
[1],
[0],
[2],
[2]])>
Example (FarmHash64) with list of inputs:
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
@ -114,7 +127,12 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf)
Args:
num_bins: Number of hash bins.
num_bins: Number of hash bins. Note that this includes the `mask_value` bin,
so the effective number of bins is `(num_bins - 1)` if `mask_value` is
set.
mask_value: A value that represents masked inputs, which are mapped to
index 0. Defaults to None, meaning no mask term will be added and the
hashing will start at index 0.
salt: A single unsigned integer or None.
If passed, the hash function used will be SipHash64, with these values
used as an additional input (known as a "salt" in cryptography).
@ -134,12 +152,13 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
"""
def __init__(self, num_bins, salt=None, name=None, **kwargs):
def __init__(self, num_bins, mask_value=None, salt=None, name=None, **kwargs):
if num_bins is None or num_bins <= 0:
raise ValueError('`num_bins` cannot be `None` or non-positive values.')
super(Hashing, self).__init__(name=name, **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell('Hashing').set(True)
self.num_bins = num_bins
self.mask_value = mask_value
self.strong_hash = True if salt is not None else False
if salt is not None:
if isinstance(salt, (tuple, list)) and len(salt) == 2:
@ -170,39 +189,22 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
inputs = self._preprocess_inputs(inputs)
if isinstance(inputs, (tuple, list)):
return self._process_input_list(inputs)
else:
return self._process_single_input(inputs)
def _process_single_input(self, inputs):
# Converts integer inputs to string.
if inputs.dtype.is_integer:
if isinstance(inputs, sparse_tensor.SparseTensor):
inputs = sparse_tensor.SparseTensor(
indices=inputs.indices,
values=string_ops.as_string(inputs.values),
dense_shape=inputs.dense_shape)
else:
inputs = string_ops.as_string(inputs)
str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
if tf_utils.is_ragged(inputs):
return ragged_functional_ops.map_flat_values(
str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
elif isinstance(inputs, sparse_tensor.SparseTensor):
sparse_values = inputs.values
sparse_hashed_values = str_to_hash_bucket(
sparse_values, self.num_bins, name='hash')
return sparse_tensor.SparseTensor(
indices=inputs.indices,
values=sparse_hashed_values,
values=self._hash_values_to_bins(inputs.values),
dense_shape=inputs.dense_shape)
else:
return str_to_hash_bucket(inputs, self.num_bins, name='hash')
return self._hash_values_to_bins(inputs)
def _process_input_list(self, inputs):
# TODO(momernick): support ragged_cross_hashed with corrected fingerprint
# and siphash.
if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
raise ValueError('Hashing with ragged input is not supported yet.')
if self.mask_value is not None:
raise ValueError(
'Cross hashing with a mask_value is not supported yet, mask_value is '
'{}.'.format(self.mask_value))
sparse_inputs = [
inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
]
@ -226,6 +228,24 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
return sparse_ops.sparse_tensor_to_dense(sparse_out)
return sparse_out
def _hash_values_to_bins(self, values):
"""Converts a non-sparse tensor of values to bin indices."""
str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
num_available_bins = self.num_bins
mask = None
# If mask_value is set, the zeroth bin is reserved for it.
if self.mask_value is not None and num_available_bins > 1:
num_available_bins -= 1
mask = math_ops.equal(values, self.mask_value)
# Convert all values to strings before hashing.
if values.dtype.is_integer:
values = string_ops.as_string(values)
values = str_to_hash_bucket(values, num_available_bins, name='hash')
if mask is not None:
values = math_ops.add(values, array_ops.ones_like(values))
values = array_ops.where(mask, array_ops.zeros_like(values), values)
return values
def _get_string_to_hash_bucket_fn(self):
"""Returns the string_to_hash_bucket op to use based on `hasher_key`."""
# string_to_hash_bucket_fast uses FarmHash64 as hash function.
@ -274,6 +294,10 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
def get_config(self):
config = {'num_bins': self.num_bins, 'salt': self.salt}
config = {
'num_bins': self.num_bins,
'salt': self.salt,
'mask_value': self.mask_value,
}
base_config = super(Hashing, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

View File

@ -51,6 +51,27 @@ class HashingTest(keras_parameterized.TestCase):
# Assert equal for hashed output that should be true on all platforms.
self.assertAllClose([[0], [0], [1], [0], [0]], output)
def test_hash_dense_input_mask_value_farmhash(self):
empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
['skywalker']])
empty_mask_output = empty_mask_layer(inp)
omar_mask_output = omar_mask_layer(inp)
# Outputs should be one more than test_hash_dense_input_farmhash (the zeroth
# bin is now reserved for masks).
self.assertAllClose([[1], [1], [2], [1], [1]], empty_mask_output)
# 'omar' should map to 0.
self.assertAllClose([[0], [1], [2], [1], [1]], omar_mask_output)
def test_hash_dense_multi_inputs_mask_value_farmhash(self):
layer = hashing.Hashing(num_bins=3, mask_value='omar')
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
['skywalker']])
inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
with self.assertRaisesRegex(ValueError, 'not supported yet'):
_ = layer([inp_1, inp_2])
def test_hash_dense_multi_inputs_farmhash(self):
layer = hashing.Hashing(num_bins=2)
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
@ -135,6 +156,24 @@ class HashingTest(keras_parameterized.TestCase):
self.assertAllClose(indices, output.indices)
self.assertAllClose([0, 0, 1, 0, 0], output.values)
def test_hash_sparse_input_mask_value_farmhash(self):
empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
inp = sparse_tensor.SparseTensor(
indices=indices,
values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
dense_shape=[3, 2])
empty_mask_output = empty_mask_layer(inp)
omar_mask_output = omar_mask_layer(inp)
self.assertAllClose(indices, omar_mask_output.indices)
self.assertAllClose(indices, empty_mask_output.indices)
# Outputs should be one more than test_hash_sparse_input_farmhash (the
# zeroth bin is now reserved for masks).
self.assertAllClose([1, 1, 2, 1, 1], empty_mask_output.values)
# 'omar' should map to 0.
self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values)
def test_hash_sparse_multi_inputs_farmhash(self):
layer = hashing.Hashing(num_bins=2)
indices = [[0, 0], [1, 0], [2, 0]]
@ -217,6 +256,22 @@ class HashingTest(keras_parameterized.TestCase):
model = training.Model(inputs=inp_t, outputs=out_t)
self.assertAllClose(out_data, model.predict(inp_data))
def test_hash_ragged_input_mask_value(self):
empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
inp_data = ragged_factory_ops.constant(
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
dtype=dtypes.string)
empty_mask_output = empty_mask_layer(inp_data)
omar_mask_output = omar_mask_layer(inp_data)
# Outputs should be one more than test_hash_ragged_string_input_farmhash
# (the zeroth bin is now reserved for masks).
expected_output = [[1, 1, 2, 1], [2, 1, 1]]
self.assertAllClose(expected_output, empty_mask_output)
# 'omar' should map to 0.
expected_output = [[0, 1, 2, 1], [2, 1, 1]]
self.assertAllClose(expected_output, omar_mask_output)
def test_hash_ragged_string_multi_inputs_farmhash(self):
layer = hashing.Hashing(num_bins=2)
inp_data_1 = ragged_factory_ops.constant(

View File

@ -130,7 +130,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "adapt"

View File

@ -130,7 +130,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "adapt"