From 0d23b37aafcbbc21dcd075e05554a71bbdd36723 Mon Sep 17 00:00:00 2001
From: Zhenyu Tan
Date: Wed, 12 Feb 2020 20:10:26 -0800
Subject: [PATCH] Adding hashing trick for Hashing layer.

PiperOrigin-RevId: 294821319
Change-Id: Id127f5df76311bc3904dd7f68c628b310cbc9e85
---
 .../keras/layers/preprocessing/hashing.py      | 73 ++++++++++----
 .../layers/preprocessing/hashing_test.py       | 99 ++++++++++++++++---
 2 files changed, 136 insertions(+), 36 deletions(-)

diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py
index 329b502ed95..d9183942adb 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
@@ -35,16 +37,39 @@ class Hashing(Layer):
   tensorflow::ops::Fingerprint to produce universal output that is consistent
   across platforms.
 
-  Usage:
+  This layer uses [FarmHash64](https://github.com/google/farmhash) by default,
+  which provides a consistent hashed output across different platforms and is
+  stable across invocations, regardless of device and context, by mixing the
+  input bits thoroughly.
+
+  If you want to obfuscate the hashed output, you can also pass a random `salt`
+  argument in the constructor. In that case, the layer will use the
+  [SipHash64](https://github.com/google/highwayhash) hash function, with
+  the `salt` value serving as additional input to the hash function.
+
+  Example (FarmHash64):
   ```python
   layer = Hashing(num_bins=3)
-  inp = np.asarray([['A', 'B'], ['C', 'A']])
+  inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
   layer(inp)
-  [[0, 0], [1, 0]]
+  [[1], [0], [1], [1], [2]]
+  ```
+
+  Example (SipHash64):
+  ```python
+  layer = Hashing(num_bins=3, salt=[133, 137])
+  inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+  layer(inp)
+  [[1], [2], [1], [0], [2]]
   ```
 
   Arguments:
     num_bins: Number of hash bins.
+    salt: A tuple/list of 2 unsigned integers. If passed, the hash
+      function used will be SipHash64, with these values used as an additional
+      input (known as a "salt" in cryptography).
+      These should be non-zero. Defaults to `None` (in that
+      case, the FarmHash64 hash function is used).
     name: Name to give to the layer.
     **kwargs: Keyword arguments to construct a layer.
 
@@ -53,38 +78,46 @@ class Hashing(Layer):
 
   Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]`
 
-  Example:
-    If the input is a 5 by 1 string tensor '[['A'], ['B'], ['C'], ['D'], ['E']]'
-    with `num_bins=2`, then output is 5 by 1 integer tensor
-    [[hash('A')], [hash('B')], [hash('C')], [hash('D')], [hash('E')]].
   """
 
-  def __init__(self, num_bins, name=None, **kwargs):
-    # TODO(tanzheny): consider adding strong hash variant.
+  def __init__(self, num_bins, salt=None, name=None, **kwargs):
+    if num_bins is None or num_bins <= 0:
+      raise ValueError('`num_bins` cannot be `None` or non-positive.')
+    if salt is not None:
+      if not isinstance(salt, (tuple, list)) or len(salt) != 2:
+        raise ValueError('`salt` must be a tuple or list of 2 unsigned '
+                         'integers, got {}'.format(salt))
     super(Hashing, self).__init__(name=name, **kwargs)
-    self._num_bins = num_bins
+    self.num_bins = num_bins
+    self.salt = salt
     self._supports_ragged_inputs = True
 
   def call(self, inputs):
     # TODO(tanzheny): Add int support.
-    # string_to_hash_bucket_fast uses FarmHash as hash function.
+    str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
     if ragged_tensor.is_ragged(inputs):
      return ragged_functional_ops.map_flat_values(
-          string_ops.string_to_hash_bucket_fast,
-          inputs,
-          num_buckets=self._num_bins,
-          name='hash')
+          str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
     elif isinstance(inputs, sparse_tensor.SparseTensor):
       sparse_values = inputs.values
-      sparse_hashed_values = string_ops.string_to_hash_bucket_fast(
-          sparse_values, self._num_bins, name='hash')
+      sparse_hashed_values = str_to_hash_bucket(
+          sparse_values, self.num_bins, name='hash')
       return sparse_tensor.SparseTensor(
           indices=inputs.indices,
           values=sparse_hashed_values,
           dense_shape=inputs.dense_shape)
     else:
-      return string_ops.string_to_hash_bucket_fast(
-          inputs, self._num_bins, name='hash')
+      return str_to_hash_bucket(inputs, self.num_bins, name='hash')
+
+  def _get_string_to_hash_bucket_fn(self):
+    """Returns the string_to_hash_bucket op to use based on `self.salt`."""
+    # string_to_hash_bucket_fast uses FarmHash64 as hash function.
+    if self.salt is None:
+      return string_ops.string_to_hash_bucket_fast
+    # string_to_hash_bucket_strong uses SipHash64 as hash function.
+    else:
+      return functools.partial(
+          string_ops.string_to_hash_bucket_strong, key=self.salt)
 
   def compute_output_shape(self, input_shape):
     return input_shape
@@ -99,6 +132,6 @@ class Hashing(Layer):
     return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
 
   def get_config(self):
-    config = {'num_bins': self._num_bins}
+    config = {'num_bins': self.num_bins, 'salt': self.salt}
     base_config = super(Hashing, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
index a49f336a8e5..0fb5aa8a4bb 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
@@ -40,43 +41,109 @@ class HashingTest(keras_parameterized.TestCase):
     layer = hashing.Hashing(num_bins=1)
     inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
     output = layer(inp)
-    self.assertAllClose(np.asarray([[0], [0], [0], [0], [0]]), output)
+    self.assertAllClose([[0], [0], [0], [0], [0]], output)
 
-  def test_hash_two_bins(self):
+  def test_hash_dense_input_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
-    inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+    inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+                      ['skywalker']])
     output = layer(inp)
-    self.assertEqual(output.numpy().max(), 1)
-    self.assertEqual(output.numpy().min(), 0)
+    # The hashed output should be identical on all platforms.
+    self.assertAllClose([[0], [0], [1], [0], [0]], output)
 
-  def test_hash_sparse_input(self):
+  def test_hash_dense_input_siphash(self):
+    layer = hashing.Hashing(num_bins=2, salt=[133, 137])
+    inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+                      ['skywalker']])
+    output = layer(inp)
+    # The hashed output should be identical on all platforms.
+    # Note the result is different from FarmHash.
+    self.assertAllClose([[0], [1], [0], [1], [0]], output)
+
+    layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
+    output_2 = layer_2(inp)
+    # Note the result is different from the one with salt (133, 137).
+    self.assertAllClose([[1], [0], [1], [0], [1]], output_2)
+
+  def test_hash_sparse_input_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
+    indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
     inp = sparse_tensor.SparseTensor(
-        indices=[[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]],
+        indices=indices,
         values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
         dense_shape=[3, 2])
     output = layer(inp)
-    self.assertEqual(output.values.numpy().max(), 1)
-    self.assertEqual(output.values.numpy().min(), 0)
+    self.assertAllClose(indices, output.indices)
+    self.assertAllClose([0, 0, 1, 0, 0], output.values)
 
-  def test_hash_ragged_string_input(self):
+  def test_hash_sparse_input_siphash(self):
+    layer = hashing.Hashing(num_bins=2, salt=[133, 137])
+    indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
+    inp = sparse_tensor.SparseTensor(
+        indices=indices,
+        values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
+        dense_shape=[3, 2])
+    output = layer(inp)
+    self.assertAllClose(output.indices, indices)
+    # The result should be the same as in test_hash_dense_input_siphash.
+    self.assertAllClose([0, 1, 0, 1, 0], output.values)
+
+    layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
+    output = layer_2(inp)
+    # The result should be the same as in test_hash_dense_input_siphash.
+    self.assertAllClose([1, 0, 1, 0, 1], output.values)
+
+  def test_hash_ragged_string_input_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
     inp_data = ragged_factory_ops.constant(
         [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
         dtype=dtypes.string)
     out_data = layer(inp_data)
-    self.assertEqual(out_data.values.numpy().max(), 1)
-    self.assertEqual(out_data.values.numpy().min(), 0)
-    # hash of 'marlo' should be same.
-    self.assertAllClose(out_data[0][2], out_data[1][0])
-    # hash of 'wire' should be same.
-    self.assertAllClose(out_data[0][3], out_data[1][2])
+    # Same hashed output as in test_hash_sparse_input_farmhash.
+    expected_output = [[0, 0, 1, 0], [1, 0, 0]]
+    self.assertAllEqual(expected_output, out_data)
 
     inp_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
     out_t = layer(inp_t)
     model = training.Model(inputs=inp_t, outputs=out_t)
     self.assertAllClose(out_data, model.predict(inp_data))
 
+  def test_hash_ragged_string_input_siphash(self):
+    layer = hashing.Hashing(num_bins=2, salt=[133, 137])
+    inp_data = ragged_factory_ops.constant(
+        [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
+        dtype=dtypes.string)
+    out_data = layer(inp_data)
+    # Same hashed output as in test_hash_dense_input_siphash.
+    expected_output = [[0, 1, 0, 1], [0, 0, 1]]
+    self.assertAllEqual(expected_output, out_data)
+
+    inp_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
+    out_t = layer(inp_t)
+    model = training.Model(inputs=inp_t, outputs=out_t)
+    self.assertAllClose(out_data, model.predict(inp_data))
+
+    layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
+    out_data = layer_2(inp_data)
+    expected_output = [[1, 0, 1, 0], [1, 1, 0]]
+    self.assertAllEqual(expected_output, out_data)
+
+    out_t = layer_2(inp_t)
+    model = training.Model(inputs=inp_t, outputs=out_t)
+    self.assertAllClose(out_data, model.predict(inp_data))
+
+  def test_invalid_inputs(self):
+    with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
+      _ = hashing.Hashing(num_bins=None)
+    with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
+      _ = hashing.Hashing(num_bins=-1)
+    with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
+      _ = hashing.Hashing(num_bins=2, salt='string')
+    with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
+      _ = hashing.Hashing(num_bins=2, salt=[1])
+    with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
+      _ = hashing.Hashing(num_bins=1, salt=constant_op.constant([133, 137]))
+
   def test_hash_compute_output_signature(self):
     input_shape = tensor_shape.TensorShape([2, 3])
     input_spec = tensor_spec.TensorSpec(input_shape, dtypes.string)
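
The hunks above change only the layer implementation and its tests. As a usage illustration outside the patch, the sketch below exercises the new `salt` argument the same way the tests do. It is a minimal sketch, assuming a TensorFlow 2.x build that already contains this change; it reuses the internal `tensorflow.python.keras.layers.preprocessing.hashing` import path from `hashing_test.py`, and the public `tf.strings.to_hash_bucket_fast` / `tf.strings.to_hash_bucket_strong` ops shown at the end correspond to the `string_ops` functions the layer dispatches between.

```python
import numpy as np
import tensorflow as tf

# Internal import path used by hashing_test.py in this patch; assumed to be
# available in a build that includes the change.
from tensorflow.python.keras.layers.preprocessing import hashing

inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'], ['skywalker']])

# Default: no salt, so the layer hashes with FarmHash64
# (string_to_hash_bucket_fast).
farmhash_layer = hashing.Hashing(num_bins=2)
print(farmhash_layer(inp))  # [[0], [0], [1], [0], [0]] per the tests above.

# With a 2-element salt, the layer hashes with SipHash64
# (string_to_hash_bucket_strong); the salt is bound as the `key` argument, so
# a different salt yields different bin assignments for the same inputs.
siphash_layer = hashing.Hashing(num_bins=2, salt=[133, 137])
print(siphash_layer(inp))  # [[0], [1], [0], [1], [0]] per the tests above.

# The same bucketing is reachable through the public string ops, which is
# handy for checking expected bins outside the layer.
print(tf.strings.to_hash_bucket_fast(['omar', 'marlo'], num_buckets=2))
print(tf.strings.to_hash_bucket_strong(['omar', 'marlo'], num_buckets=2,
                                       key=[133, 137]))
```

Because `get_config()` now serializes `salt` alongside `num_bins`, a layer recreated from its config keeps the same hash function and therefore produces the same bin assignments.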