Adding hashing trick for Hashing layer.
PiperOrigin-RevId: 294821319
Change-Id: Id127f5df76311bc3904dd7f68c628b310cbc9e85
commit 0d23b37aaf (parent f6a8bcd0f2)
hashing.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
@@ -35,16 +37,39 @@ class Hashing(Layer):
   tensorflow::ops::Fingerprint to produce universal output that is consistent
   across platforms.
 
-  Usage:
+  This layer uses [FarmHash64](https://github.com/google/farmhash) by default,
+  which provides a consistent hashed output across different platforms and is
+  stable across invocations, regardless of device and context, by mixing the
+  input bits thoroughly.
+
+  If you want to obfuscate the hashed output, you can also pass a random `salt`
+  argument in the constructor. In that case, the layer will use the
+  [SipHash64](https://github.com/google/highwayhash) hash function, with
+  the `salt` value serving as additional input to the hash function.
+
+  Example (FarmHash64):
   ```python
   layer = Hashing(num_bins=3)
-  inp = np.asarray([['A', 'B'], ['C', 'A']])
+  inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
   layer(inputs)
-  [[0, 0], [1, 0]]
+  [[1], [0], [1], [1], [2]]
+  ```
+
+  Example (SipHash64):
+  ```python
+  layer = Hashing(num_bins=3, salt=[133, 137])
+  inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+  layer(inputs)
+  [[1], [2], [1], [0], [2]]
   ```
 
   Arguments:
     num_bins: Number of hash bins.
+    salt: A tuple/list of 2 unsigned integer numbers. If passed, the hash
+      function used will be SipHash64, with these values used as an additional
+      input (known as a "salt" in cryptography).
+      These should be non-zero. Defaults to `None` (in that
+      case, the FarmHash64 hash function is used).
     name: Name to give to the layer.
     **kwargs: Keyword arguments to construct a layer.
 
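
The rewritten docstring above explains the two hashing modes. As a minimal sketch that is not part of the commit, the same binning can be reproduced with the public `tf.strings` aliases of the ops the layer wraps; the input strings, `num_buckets=3`, and the `[133, 137]` salt below are placeholder values, and the bin ids are not asserted because they depend on the hash:

```python
import tensorflow as tf

inp = tf.constant([['A'], ['B'], ['C'], ['D'], ['E']])

# FarmHash64 path (no salt): the op behind string_to_hash_bucket_fast.
farmhash_bins = tf.strings.to_hash_bucket_fast(inp, num_buckets=3)

# SipHash64 path (salted): `key` takes the two-integer salt.
siphash_bins = tf.strings.to_hash_bucket_strong(
    inp, num_buckets=3, key=[133, 137])

print(farmhash_bins.numpy(), siphash_bins.numpy())
```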
@@ -53,38 +78,46 @@ class Hashing(Layer):
 
   Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]`
 
-  Example:
-  If the input is a 5 by 1 string tensor '[['A'], ['B'], ['C'], ['D'], ['E']]'
-  with `num_bins=2`, then output is 5 by 1 integer tensor
-  [[hash('A')], [hash('B')], [hash('C')], [hash('D')], [hash('E')]].
   """
 
-  def __init__(self, num_bins, name=None, **kwargs):
-    # TODO(tanzheny): consider adding strong hash variant.
+  def __init__(self, num_bins, salt=None, name=None, **kwargs):
+    if num_bins is None or num_bins <= 0:
+      raise ValueError('`num_bins` cannot be `None` or non-positive values.')
+    if salt is not None:
+      if not isinstance(salt, (tuple, list)) or len(salt) != 2:
+        raise ValueError('`salt` must be a tuple or list of 2 unsigned '
+                         'integer numbers, got {}'.format(salt))
     super(Hashing, self).__init__(name=name, **kwargs)
-    self._num_bins = num_bins
+    self.num_bins = num_bins
+    self.salt = salt
     self._supports_ragged_inputs = True
 
   def call(self, inputs):
     # TODO(tanzheny): Add int support.
-    # string_to_hash_bucket_fast uses FarmHash as hash function.
+    str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
     if ragged_tensor.is_ragged(inputs):
       return ragged_functional_ops.map_flat_values(
-          string_ops.string_to_hash_bucket_fast,
-          inputs,
-          num_buckets=self._num_bins,
-          name='hash')
+          str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
     elif isinstance(inputs, sparse_tensor.SparseTensor):
       sparse_values = inputs.values
-      sparse_hashed_values = string_ops.string_to_hash_bucket_fast(
-          sparse_values, self._num_bins, name='hash')
+      sparse_hashed_values = str_to_hash_bucket(
+          sparse_values, self.num_bins, name='hash')
       return sparse_tensor.SparseTensor(
           indices=inputs.indices,
           values=sparse_hashed_values,
           dense_shape=inputs.dense_shape)
     else:
-      return string_ops.string_to_hash_bucket_fast(
-          inputs, self._num_bins, name='hash')
+      return str_to_hash_bucket(inputs, self.num_bins, name='hash')
+
+  def _get_string_to_hash_bucket_fn(self):
+    """Returns the string_to_hash_bucket op to use based on `hasher_key`."""
+    # string_to_hash_bucket_fast uses FarmHash64 as hash function.
+    if self.salt is None:
+      return string_ops.string_to_hash_bucket_fast
+    # string_to_hash_bucket_strong uses SipHash64 as hash function.
+    else:
+      return functools.partial(
+          string_ops.string_to_hash_bucket_strong, key=self.salt)
 
   def compute_output_shape(self, input_shape):
     return input_shape
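
The new `_get_string_to_hash_bucket_fn` helper is the heart of the change: with no salt it hands back `string_to_hash_bucket_fast` (FarmHash64), and with a salt it pre-binds the key into `string_to_hash_bucket_strong` (SipHash64) via `functools.partial`, so `call` can invoke either function with the same `(values, num_buckets, name=...)` argument list. A standalone sketch of that pattern, assuming the public `tf.strings` aliases rather than the internal `string_ops` module:

```python
import functools

import tensorflow as tf


def make_hash_fn(salt=None):
  """Mirror of the layer's helper: pick the hash op and pre-bind the salt."""
  if salt is None:
    return tf.strings.to_hash_bucket_fast
  return functools.partial(tf.strings.to_hash_bucket_strong, key=salt)


values = tf.constant(['omar', 'stringer', 'marlo'])
for fn in (make_hash_fn(), make_hash_fn(salt=[133, 137])):
  # Both callables share the (input, num_buckets, name=...) calling convention.
  print(fn(values, 2, name='hash').numpy())
```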
@@ -99,6 +132,6 @@ class Hashing(Layer):
     return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
 
   def get_config(self):
-    config = {'num_bins': self._num_bins}
+    config = {'num_bins': self.num_bins, 'salt': self.salt}
     base_config = super(Hashing, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
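
Since `num_bins` and `salt` are now public attributes and both appear in `get_config`, the layer can round-trip through Keras serialization. A minimal sketch, assuming the same in-repo `hashing` module import that the tests below use:

```python
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
config = layer.get_config()
# `config` carries both fields alongside the base Layer config,
# e.g. {'num_bins': 2, 'salt': [133, 137], 'name': ..., ...}.
restored = hashing.Hashing.from_config(config)
assert restored.num_bins == 2 and restored.salt == [133, 137]
```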
hashing_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
@@ -40,43 +41,109 @@ class HashingTest(keras_parameterized.TestCase):
     layer = hashing.Hashing(num_bins=1)
     inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
     output = layer(inp)
-    self.assertAllClose(np.asarray([[0], [0], [0], [0], [0]]), output)
+    self.assertAllClose([[0], [0], [0], [0], [0]], output)
 
-  def test_hash_two_bins(self):
+  def test_hash_dense_input_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
-    inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+    inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+                      ['skywalker']])
     output = layer(inp)
-    self.assertEqual(output.numpy().max(), 1)
-    self.assertEqual(output.numpy().min(), 0)
+    # Assert equal for hashed output that should be true on all platforms.
+    self.assertAllClose([[0], [0], [1], [0], [0]], output)
 
-  def test_hash_sparse_input(self):
+  def test_hash_dense_input_siphash(self):
+    layer = hashing.Hashing(num_bins=2, salt=[133, 137])
+    inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+                      ['skywalker']])
+    output = layer(inp)
+    # Assert equal for hashed output that should be true on all platforms.
+    # Note the result is different from FarmHash.
+    self.assertAllClose([[0], [1], [0], [1], [0]], output)
+
+    layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
+    output_2 = layer_2(inp)
+    # Note the result is different from (133, 137).
+    self.assertAllClose([[1], [0], [1], [0], [1]], output_2)
+
+  def test_hash_sparse_input_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
+    indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
     inp = sparse_tensor.SparseTensor(
-        indices=[[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]],
+        indices=indices,
         values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
         dense_shape=[3, 2])
     output = layer(inp)
-    self.assertEqual(output.values.numpy().max(), 1)
-    self.assertEqual(output.values.numpy().min(), 0)
+    self.assertAllClose(indices, output.indices)
+    self.assertAllClose([0, 0, 1, 0, 0], output.values)
 
-  def test_hash_ragged_string_input(self):
+  def test_hash_sparse_input_siphash(self):
+    layer = hashing.Hashing(num_bins=2, salt=[133, 137])
+    indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
+    inp = sparse_tensor.SparseTensor(
+        indices=indices,
+        values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
+        dense_shape=[3, 2])
+    output = layer(inp)
+    self.assertAllClose(output.indices, indices)
+    # The result should be same with test_hash_dense_input_siphash.
+    self.assertAllClose([0, 1, 0, 1, 0], output.values)
+
+    layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
+    output = layer_2(inp)
+    # The result should be same with test_hash_dense_input_siphash.
+    self.assertAllClose([1, 0, 1, 0, 1], output.values)
+
+  def test_hash_ragged_string_input_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
     inp_data = ragged_factory_ops.constant(
         [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
         dtype=dtypes.string)
     out_data = layer(inp_data)
-    self.assertEqual(out_data.values.numpy().max(), 1)
-    self.assertEqual(out_data.values.numpy().min(), 0)
-    # hash of 'marlo' should be same.
-    self.assertAllClose(out_data[0][2], out_data[1][0])
-    # hash of 'wire' should be same.
-    self.assertAllClose(out_data[0][3], out_data[1][2])
+    # Same hashed output as test_hash_sparse_input_farmhash
+    expected_output = [[0, 0, 1, 0], [1, 0, 0]]
+    self.assertAllEqual(expected_output, out_data)
 
     inp_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
     out_t = layer(inp_t)
     model = training.Model(inputs=inp_t, outputs=out_t)
     self.assertAllClose(out_data, model.predict(inp_data))
 
+  def test_hash_ragged_string_input_siphash(self):
+    layer = hashing.Hashing(num_bins=2, salt=[133, 137])
+    inp_data = ragged_factory_ops.constant(
+        [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
+        dtype=dtypes.string)
+    out_data = layer(inp_data)
+    # Same hashed output as test_hash_dense_input_siphash
+    expected_output = [[0, 1, 0, 1], [0, 0, 1]]
+    self.assertAllEqual(expected_output, out_data)
+
+    inp_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
+    out_t = layer(inp_t)
+    model = training.Model(inputs=inp_t, outputs=out_t)
+    self.assertAllClose(out_data, model.predict(inp_data))
+
+    layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
+    out_data = layer_2(inp_data)
+    expected_output = [[1, 0, 1, 0], [1, 1, 0]]
+    self.assertAllEqual(expected_output, out_data)
+
+    out_t = layer_2(inp_t)
+    model = training.Model(inputs=inp_t, outputs=out_t)
+    self.assertAllClose(out_data, model.predict(inp_data))
+
+  def test_invalid_inputs(self):
+    with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
+      _ = hashing.Hashing(num_bins=None)
+    with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
+      _ = hashing.Hashing(num_bins=-1)
+    with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
+      _ = hashing.Hashing(num_bins=2, salt='string')
+    with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
+      _ = hashing.Hashing(num_bins=2, salt=[1])
+    with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
+      _ = hashing.Hashing(num_bins=1, salt=constant_op.constant([133, 137]))
+
   def test_hash_compute_output_signature(self):
     input_shape = tensor_shape.TensorShape([2, 3])
     input_spec = tensor_spec.TensorSpec(input_shape, dtypes.string)
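
The new `test_invalid_inputs` case exercises the constructor validation added in `__init__`. As a quick illustration, again assuming the tests' `hashing` module import, each rejected configuration raises `ValueError` before any hashing happens:

```python
bad_configs = [
    {'num_bins': None},                 # missing bin count
    {'num_bins': -1},                   # non-positive bin count
    {'num_bins': 2, 'salt': [1]},       # salt needs exactly two integers
    {'num_bins': 2, 'salt': 'string'},  # salt must be a tuple/list
]
for kwargs in bad_configs:
  try:
    hashing.Hashing(**kwargs)
  except ValueError as e:
    print('rejected', kwargs, '->', e)
```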