Support a mask_value argument for preprocessing.hashing
This argument lets callers specify a value that always maps to index 0. For now it is supported only for a single tensor input, because the desired behavior when crossing multiple inputs is unclear.

PiperOrigin-RevId: 351904657
Change-Id: I8ae3fd88ef94f7b1244cd1a6da7adbc2a40dfef1
This commit is contained in:
parent
73a6839fb3
commit
5fb1d0e838
tensorflow
python/keras/layers/preprocessing
tools/api/golden
@ -97,8 +97,10 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
@ -106,8 +108,6 @@ py_library(
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
"//tensorflow/python/ops/ragged:ragged_functional_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -28,11 +28,11 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras.engine import base_preprocessing_layer
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_sparse_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -71,6 +71,19 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
|
||||
[1],
|
||||
[2]])>
|
||||
|
||||
Example (FarmHash64) with a mask value:
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
|
||||
... mask_value='')
|
||||
>>> inp = [['A'], ['B'], [''], ['C'], ['D']]
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[1],
|
||||
[1],
|
||||
[0],
|
||||
[2],
|
||||
[2]])>
|
||||
|
||||
|
||||
Example (FarmHash64) with list of inputs:
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
|
||||
@ -114,7 +127,12 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
|
||||
Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf)
|
||||
|
||||
Args:
|
||||
num_bins: Number of hash bins.
|
||||
num_bins: Number of hash bins. Note that this includes the `mask_value` bin,
|
||||
so the effective number of bins is `(num_bins - 1)` if `mask_value` is
|
||||
set.
|
||||
mask_value: A value that represents masked inputs, which are mapped to
|
||||
index 0. Defaults to None, meaning no mask term will be added and the
|
||||
hashing will start at index 0.
|
||||
salt: A single unsigned integer or None.
|
||||
If passed, the hash function used will be SipHash64, with these values
|
||||
used as an additional input (known as a "salt" in cryptography).
|
||||
@ -134,12 +152,13 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_bins, salt=None, name=None, **kwargs):
|
||||
def __init__(self, num_bins, mask_value=None, salt=None, name=None, **kwargs):
|
||||
if num_bins is None or num_bins <= 0:
|
||||
raise ValueError('`num_bins` cannot be `None` or non-positive values.')
|
||||
super(Hashing, self).__init__(name=name, **kwargs)
|
||||
base_preprocessing_layer.keras_kpl_gauge.get_cell('Hashing').set(True)
|
||||
self.num_bins = num_bins
|
||||
self.mask_value = mask_value
|
||||
self.strong_hash = True if salt is not None else False
|
||||
if salt is not None:
|
||||
if isinstance(salt, (tuple, list)) and len(salt) == 2:
|
||||
@ -170,39 +189,22 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
|
||||
inputs = self._preprocess_inputs(inputs)
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
return self._process_input_list(inputs)
|
||||
else:
|
||||
return self._process_single_input(inputs)
|
||||
|
||||
def _process_single_input(self, inputs):
|
||||
# Converts integer inputs to string.
|
||||
if inputs.dtype.is_integer:
|
||||
if isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
inputs = sparse_tensor.SparseTensor(
|
||||
indices=inputs.indices,
|
||||
values=string_ops.as_string(inputs.values),
|
||||
dense_shape=inputs.dense_shape)
|
||||
else:
|
||||
inputs = string_ops.as_string(inputs)
|
||||
str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
|
||||
if tf_utils.is_ragged(inputs):
|
||||
return ragged_functional_ops.map_flat_values(
|
||||
str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
|
||||
elif isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
sparse_values = inputs.values
|
||||
sparse_hashed_values = str_to_hash_bucket(
|
||||
sparse_values, self.num_bins, name='hash')
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=inputs.indices,
|
||||
values=sparse_hashed_values,
|
||||
values=self._hash_values_to_bins(inputs.values),
|
||||
dense_shape=inputs.dense_shape)
|
||||
else:
|
||||
return str_to_hash_bucket(inputs, self.num_bins, name='hash')
|
||||
return self._hash_values_to_bins(inputs)
|
||||
|
||||
def _process_input_list(self, inputs):
|
||||
# TODO(momernick): support ragged_cross_hashed with corrected fingerprint
|
||||
# and siphash.
|
||||
if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
|
||||
raise ValueError('Hashing with ragged input is not supported yet.')
|
||||
if self.mask_value is not None:
|
||||
raise ValueError(
|
||||
'Cross hashing with a mask_value is not supported yet, mask_value is '
|
||||
'{}.'.format(self.mask_value))
|
||||
sparse_inputs = [
|
||||
inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
|
||||
]
|
||||
@ -226,6 +228,24 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
|
||||
return sparse_ops.sparse_tensor_to_dense(sparse_out)
|
||||
return sparse_out
|
||||
|
||||
def _hash_values_to_bins(self, values):
  """Hashes a non-sparse tensor of values into bin indices.

  When `self.mask_value` is set and more than one bin is configured, bin 0 is
  reserved for the mask: every other value is hashed into the remaining
  `num_bins - 1` buckets and shifted up by one, while entries equal to
  `mask_value` are mapped to 0. Otherwise all `num_bins` buckets are used
  directly.

  Args:
    values: Tensor of values to hash. Integer tensors are converted to
      strings before hashing.

  Returns:
    A tensor of bin indices with one entry per input value.
  """
  hash_fn = self._get_string_to_hash_bucket_fn()

  # Bin 0 can only be set aside for the mask if another bin remains for the
  # hashed values.
  reserve_mask_bin = self.mask_value is not None and self.num_bins > 1
  if reserve_mask_bin:
    bucket_count = self.num_bins - 1
    # Compare against the raw values, i.e. before integers are stringified.
    is_masked = math_ops.equal(values, self.mask_value)
  else:
    bucket_count = self.num_bins
    is_masked = None

  # The hashing op only takes strings, so integers are converted first.
  if values.dtype.is_integer:
    values = string_ops.as_string(values)
  bins = hash_fn(values, bucket_count, name='hash')

  if is_masked is None:
    return bins

  # Shift hashed values into [1, num_bins) and send masked entries to 0.
  shifted = math_ops.add(bins, array_ops.ones_like(bins))
  return array_ops.where(is_masked, array_ops.zeros_like(shifted), shifted)
|
||||
|
||||
def _get_string_to_hash_bucket_fn(self):
|
||||
"""Returns the string_to_hash_bucket op to use based on `hasher_key`."""
|
||||
# string_to_hash_bucket_fast uses FarmHash64 as hash function.
|
||||
@ -274,6 +294,10 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer):
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
|
||||
|
||||
def get_config(self):
  """Returns the layer configuration as a dict.

  The config of the base layer is extended with the `num_bins`, `salt` and
  `mask_value` attributes of this layer. If a key appears in both, the value
  set here takes precedence because it is added last.
  """
  config = {
      'num_bins': self.num_bins,
      'salt': self.salt,
      'mask_value': self.mask_value,
  }
  base_config = super(Hashing, self).get_config()
  return dict(list(base_config.items()) + list(config.items()))
|
||||
|
@ -51,6 +51,27 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [0], [0]], output)
|
||||
|
||||
def test_hash_dense_input_mask_value_farmhash(self):
  """Dense input: bin 0 is reserved for the mask value."""
  empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
  omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
  names = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
                      ['skywalker']])
  empty_mask_output = empty_mask_layer(names)
  omar_mask_output = omar_mask_layer(names)

  # '' does not occur in the input, so nothing is masked: every output is one
  # more than in test_hash_dense_input_farmhash because the zeroth bin is now
  # reserved for masks.
  expected_without_hit = [[1], [1], [2], [1], [1]]
  self.assertAllClose(expected_without_hit, empty_mask_output)

  # With 'omar' as the mask value, the first row maps to 0.
  expected_with_hit = [[0], [1], [2], [1], [1]]
  self.assertAllClose(expected_with_hit, omar_mask_output)
|
||||
|
||||
def test_hash_dense_multi_inputs_mask_value_farmhash(self):
  """Crossing a list of inputs while `mask_value` is set raises ValueError."""
  masked_layer = hashing.Hashing(num_bins=3, mask_value='omar')
  names = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
                      ['skywalker']])
  letters = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
  # Passing a list of tensors requests a cross hash, which is rejected when a
  # mask value is configured.
  with self.assertRaisesRegex(ValueError, 'not supported yet'):
    masked_layer([names, letters])
|
||||
|
||||
def test_hash_dense_multi_inputs_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
@ -135,6 +156,24 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(indices, output.indices)
|
||||
self.assertAllClose([0, 0, 1, 0, 0], output.values)
|
||||
|
||||
def test_hash_sparse_input_mask_value_farmhash(self):
  """Sparse input: indices are kept and bin 0 is reserved for the mask."""
  empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
  omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
  coords = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
  sparse_inp = sparse_tensor.SparseTensor(
      indices=coords,
      values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
      dense_shape=[3, 2])
  empty_mask_output = empty_mask_layer(sparse_inp)
  omar_mask_output = omar_mask_layer(sparse_inp)

  # Hashing only rewrites the values; the sparse indices must be unchanged.
  for hashed in (omar_mask_output, empty_mask_output):
    self.assertAllClose(coords, hashed.indices)

  # Nothing equals '', so every value is one more than in
  # test_hash_sparse_input_farmhash (the zeroth bin is now reserved for
  # masks).
  self.assertAllClose([1, 1, 2, 1, 1], empty_mask_output.values)
  # 'omar' is the mask value and should map to 0.
  self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values)
|
||||
|
||||
def test_hash_sparse_multi_inputs_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
indices = [[0, 0], [1, 0], [2, 0]]
|
||||
@ -217,6 +256,22 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_hash_ragged_input_mask_value(self):
  """Ragged input: bin 0 is reserved for the mask value."""
  empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
  omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
  ragged_names = ragged_factory_ops.constant(
      [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
      dtype=dtypes.string)
  empty_mask_output = empty_mask_layer(ragged_names)
  omar_mask_output = omar_mask_layer(ragged_names)

  # Nothing equals '', so every output is one more than in
  # test_hash_ragged_string_input_farmhash (the zeroth bin is now reserved
  # for masks).
  unmasked_expected = [[1, 1, 2, 1], [2, 1, 1]]
  self.assertAllClose(unmasked_expected, empty_mask_output)

  # 'omar' is the mask value and should map to 0.
  masked_expected = [[0, 1, 2, 1], [2, 1, 1]]
  self.assertAllClose(masked_expected, omar_mask_output)
|
||||
|
||||
def test_hash_ragged_string_multi_inputs_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_data_1 = ragged_factory_ops.constant(
|
||||
|
@ -130,7 +130,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "adapt"
|
||||
|
@ -130,7 +130,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "adapt"
|
||||
|
Loading…
Reference in New Issue
Block a user