Adding hashing trick for Hashing layer.
PiperOrigin-RevId: 294821319 Change-Id: Id127f5df76311bc3904dd7f68c628b310cbc9e85
This commit is contained in:
parent
f6a8bcd0f2
commit
0d23b37aaf
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -35,16 +37,39 @@ class Hashing(Layer):
|
||||
tensorflow::ops::Fingerprint to produce universal output that is consistent
|
||||
across platforms.
|
||||
|
||||
Usage:
|
||||
This layer uses [FarmHash64](https://github.com/google/farmhash) by default,
|
||||
which provides a consistent hashed output across different platforms and is
|
||||
stable across invocations, regardless of device and context, by mixing the
|
||||
input bits thoroughly.
|
||||
|
||||
If you want to obfuscate the hashed output, you can also pass a random `salt`
|
||||
argument in the constructor. In that case, the layer will use the
|
||||
[SipHash64](https://github.com/google/highwayhash) hash function, with
|
||||
the `salt` value serving as additional input to the hash function.
|
||||
|
||||
Example (FarmHash64):
|
||||
```python
|
||||
layer = Hashing(num_bins=3)
|
||||
inp = np.asarray([['A', 'B'], ['C', 'A']])
|
||||
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
layer(inputs)
|
||||
[[0, 0], [1, 0]]
|
||||
[[1], [0], [1], [1], [2]]
|
||||
```
|
||||
|
||||
Example (SipHash64):
|
||||
```python
|
||||
layer = Hashing(num_bins=3, salt=[133, 137])
|
||||
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
layer(inputs)
|
||||
[[1], [2], [1], [0], [2]]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
num_bins: Number of hash bins.
|
||||
salt: A tuple/list of 2 unsigned integer numbers. If passed, the hash
|
||||
function used will be SipHash64, with these values used as an additional
|
||||
input (known as a "salt" in cryptography).
|
||||
These should be non-zero. Defaults to `None` (in that
|
||||
case, the FarmHash64 hash function is used).
|
||||
name: Name to give to the layer.
|
||||
**kwargs: Keyword arguments to construct a layer.
|
||||
|
||||
@ -53,38 +78,46 @@ class Hashing(Layer):
|
||||
|
||||
Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]`
|
||||
|
||||
Example:
|
||||
If the input is a 5 by 1 string tensor '[['A'], ['B'], ['C'], ['D'], ['E']]'
|
||||
with `num_bins=2`, then output is 5 by 1 integer tensor
|
||||
[[hash('A')], [hash('B')], [hash('C')], [hash('D')], [hash('E')]].
|
||||
"""
|
||||
|
||||
def __init__(self, num_bins, name=None, **kwargs):
|
||||
# TODO(tanzheny): consider adding strong hash variant.
|
||||
def __init__(self, num_bins, salt=None, name=None, **kwargs):
|
||||
if num_bins is None or num_bins <= 0:
|
||||
raise ValueError('`num_bins` cannot be `None` or non-positive values.')
|
||||
if salt is not None:
|
||||
if not isinstance(salt, (tuple, list)) or len(salt) != 2:
|
||||
raise ValueError('`salt` must be a tuple or list of 2 unsigned '
|
||||
'integer numbers, got {}'.format(salt))
|
||||
super(Hashing, self).__init__(name=name, **kwargs)
|
||||
self._num_bins = num_bins
|
||||
self.num_bins = num_bins
|
||||
self.salt = salt
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, inputs):
|
||||
# TODO(tanzheny): Add int support.
|
||||
# string_to_hash_bucket_fast uses FarmHash as hash function.
|
||||
str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
|
||||
if ragged_tensor.is_ragged(inputs):
|
||||
return ragged_functional_ops.map_flat_values(
|
||||
string_ops.string_to_hash_bucket_fast,
|
||||
inputs,
|
||||
num_buckets=self._num_bins,
|
||||
name='hash')
|
||||
str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
|
||||
elif isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
sparse_values = inputs.values
|
||||
sparse_hashed_values = string_ops.string_to_hash_bucket_fast(
|
||||
sparse_values, self._num_bins, name='hash')
|
||||
sparse_hashed_values = str_to_hash_bucket(
|
||||
sparse_values, self.num_bins, name='hash')
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=inputs.indices,
|
||||
values=sparse_hashed_values,
|
||||
dense_shape=inputs.dense_shape)
|
||||
else:
|
||||
return string_ops.string_to_hash_bucket_fast(
|
||||
inputs, self._num_bins, name='hash')
|
||||
return str_to_hash_bucket(inputs, self.num_bins, name='hash')
|
||||
|
||||
def _get_string_to_hash_bucket_fn(self):
|
||||
"""Returns the string_to_hash_bucket op to use based on `hasher_key`."""
|
||||
# string_to_hash_bucket_fast uses FarmHash64 as hash function.
|
||||
if self.salt is None:
|
||||
return string_ops.string_to_hash_bucket_fast
|
||||
# string_to_hash_bucket_strong uses SipHash64 as hash function.
|
||||
else:
|
||||
return functools.partial(
|
||||
string_ops.string_to_hash_bucket_strong, key=self.salt)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
@ -99,6 +132,6 @@ class Hashing(Layer):
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
||||
|
||||
def get_config(self):
|
||||
config = {'num_bins': self._num_bins}
|
||||
config = {'num_bins': self.num_bins, 'salt': self.salt}
|
||||
base_config = super(Hashing, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -40,43 +41,109 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
layer = hashing.Hashing(num_bins=1)
|
||||
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
output = layer(inp)
|
||||
self.assertAllClose(np.asarray([[0], [0], [0], [0], [0]]), output)
|
||||
self.assertAllClose([[0], [0], [0], [0], [0]], output)
|
||||
|
||||
def test_hash_two_bins(self):
|
||||
def test_hash_dense_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
['skywalker']])
|
||||
output = layer(inp)
|
||||
self.assertEqual(output.numpy().max(), 1)
|
||||
self.assertEqual(output.numpy().min(), 0)
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [0], [0]], output)
|
||||
|
||||
def test_hash_sparse_input(self):
|
||||
def test_hash_dense_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
['skywalker']])
|
||||
output = layer(inp)
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
# Note the result is different from FarmHash.
|
||||
self.assertAllClose([[0], [1], [0], [1], [0]], output)
|
||||
|
||||
layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
|
||||
output_2 = layer_2(inp)
|
||||
# Note the result is different from (133, 137).
|
||||
self.assertAllClose([[1], [0], [1], [0], [1]], output_2)
|
||||
|
||||
def test_hash_sparse_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
|
||||
inp = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]],
|
||||
indices=indices,
|
||||
values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
|
||||
dense_shape=[3, 2])
|
||||
output = layer(inp)
|
||||
self.assertEqual(output.values.numpy().max(), 1)
|
||||
self.assertEqual(output.values.numpy().min(), 0)
|
||||
self.assertAllClose(indices, output.indices)
|
||||
self.assertAllClose([0, 0, 1, 0, 0], output.values)
|
||||
|
||||
def test_hash_ragged_string_input(self):
|
||||
def test_hash_sparse_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
|
||||
inp = sparse_tensor.SparseTensor(
|
||||
indices=indices,
|
||||
values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
|
||||
dense_shape=[3, 2])
|
||||
output = layer(inp)
|
||||
self.assertAllClose(output.indices, indices)
|
||||
# The result should be same with test_hash_dense_input_siphash.
|
||||
self.assertAllClose([0, 1, 0, 1, 0], output.values)
|
||||
|
||||
layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
|
||||
output = layer_2(inp)
|
||||
# The result should be same with test_hash_dense_input_siphash.
|
||||
self.assertAllClose([1, 0, 1, 0, 1], output.values)
|
||||
|
||||
def test_hash_ragged_string_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_data = ragged_factory_ops.constant(
|
||||
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
|
||||
dtype=dtypes.string)
|
||||
out_data = layer(inp_data)
|
||||
self.assertEqual(out_data.values.numpy().max(), 1)
|
||||
self.assertEqual(out_data.values.numpy().min(), 0)
|
||||
# hash of 'marlo' should be same.
|
||||
self.assertAllClose(out_data[0][2], out_data[1][0])
|
||||
# hash of 'wire' should be same.
|
||||
self.assertAllClose(out_data[0][3], out_data[1][2])
|
||||
# Same hashed output as test_hash_sparse_input_farmhash
|
||||
expected_output = [[0, 0, 1, 0], [1, 0, 0]]
|
||||
self.assertAllEqual(expected_output, out_data)
|
||||
|
||||
inp_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
|
||||
out_t = layer(inp_t)
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_hash_ragged_string_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
inp_data = ragged_factory_ops.constant(
|
||||
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
|
||||
dtype=dtypes.string)
|
||||
out_data = layer(inp_data)
|
||||
# Same hashed output as test_hash_dense_input_siphash
|
||||
expected_output = [[0, 1, 0, 1], [0, 0, 1]]
|
||||
self.assertAllEqual(expected_output, out_data)
|
||||
|
||||
inp_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
|
||||
out_t = layer(inp_t)
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
|
||||
out_data = layer_2(inp_data)
|
||||
expected_output = [[1, 0, 1, 0], [1, 1, 0]]
|
||||
self.assertAllEqual(expected_output, out_data)
|
||||
|
||||
out_t = layer_2(inp_t)
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_invalid_inputs(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
|
||||
_ = hashing.Hashing(num_bins=None)
|
||||
with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
|
||||
_ = hashing.Hashing(num_bins=-1)
|
||||
with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
|
||||
_ = hashing.Hashing(num_bins=2, salt='string')
|
||||
with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
|
||||
_ = hashing.Hashing(num_bins=2, salt=[1])
|
||||
with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
|
||||
_ = hashing.Hashing(num_bins=1, salt=constant_op.constant([133, 137]))
|
||||
|
||||
def test_hash_compute_output_signature(self):
|
||||
input_shape = tensor_shape.TensorShape([2, 3])
|
||||
input_spec = tensor_spec.TensorSpec(input_shape, dtypes.string)
|
||||
|
Loading…
Reference in New Issue
Block a user