diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD index d13473aea3f..50461573f84 100644 --- a/tensorflow/python/keras/layers/preprocessing/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/BUILD @@ -97,8 +97,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", @@ -106,8 +108,6 @@ py_library( "//tensorflow/python:tensor_spec", "//tensorflow/python:tensor_util", "//tensorflow/python/keras/engine", - "//tensorflow/python/keras/utils:tf_utils", - "//tensorflow/python/ops/ragged:ragged_functional_ops", "//tensorflow/python/ops/ragged:ragged_tensor", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py index cef41838194..925e1caa73d 100644 --- a/tensorflow/python/keras/layers/preprocessing/hashing.py +++ b/tensorflow/python/keras/layers/preprocessing/hashing.py @@ -28,11 +28,11 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras.engine import base_preprocessing_layer -from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops -from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util.tf_export import keras_export @@ -71,6 +71,19 @@ class 
Hashing(base_preprocessing_layer.PreprocessingLayer): [1], [2]])> + Example (FarmHash64) with a mask value: + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3, + ... mask_value='') + >>> inp = [['A'], ['B'], [''], ['C'], ['D']] + >>> layer(inp) + + Example (FarmHash64) with list of inputs: >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3) @@ -114,7 +127,12 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer): Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf) Args: - num_bins: Number of hash bins. + num_bins: Number of hash bins. Note that this includes the `mask_value` bin, + so the effective number of bins is `(num_bins - 1)` if `mask_value` is + set. + mask_value: A value that represents masked inputs, which are mapped to + index 0. Defaults to None, meaning no mask term will be added and the + hashing will start at index 0. salt: A single unsigned integer or None. If passed, the hash function used will be SipHash64, with these values used as an additional input (known as a "salt" in cryptography). 
@@ -134,12 +152,13 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer): """ - def __init__(self, num_bins, salt=None, name=None, **kwargs): + def __init__(self, num_bins, mask_value=None, salt=None, name=None, **kwargs): if num_bins is None or num_bins <= 0: raise ValueError('`num_bins` cannot be `None` or non-positive values.') super(Hashing, self).__init__(name=name, **kwargs) base_preprocessing_layer.keras_kpl_gauge.get_cell('Hashing').set(True) self.num_bins = num_bins + self.mask_value = mask_value self.strong_hash = True if salt is not None else False if salt is not None: if isinstance(salt, (tuple, list)) and len(salt) == 2: @@ -170,39 +189,22 @@ class Hashing(base_preprocessing_layer.PreprocessingLayer): inputs = self._preprocess_inputs(inputs) if isinstance(inputs, (tuple, list)): return self._process_input_list(inputs) - else: - return self._process_single_input(inputs) - - def _process_single_input(self, inputs): - # Converts integer inputs to string. - if inputs.dtype.is_integer: - if isinstance(inputs, sparse_tensor.SparseTensor): - inputs = sparse_tensor.SparseTensor( - indices=inputs.indices, - values=string_ops.as_string(inputs.values), - dense_shape=inputs.dense_shape) - else: - inputs = string_ops.as_string(inputs) - str_to_hash_bucket = self._get_string_to_hash_bucket_fn() - if tf_utils.is_ragged(inputs): - return ragged_functional_ops.map_flat_values( - str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash') elif isinstance(inputs, sparse_tensor.SparseTensor): - sparse_values = inputs.values - sparse_hashed_values = str_to_hash_bucket( - sparse_values, self.num_bins, name='hash') return sparse_tensor.SparseTensor( indices=inputs.indices, - values=sparse_hashed_values, + values=self._hash_values_to_bins(inputs.values), dense_shape=inputs.dense_shape) - else: - return str_to_hash_bucket(inputs, self.num_bins, name='hash') + return self._hash_values_to_bins(inputs) def _process_input_list(self, inputs): # TODO(momernick): 
def _hash_values_to_bins(self, values):
  """Hashes `values` into bin indices, optionally reserving bin 0 for masks.

  When `self.mask_value` is set and `self.num_bins` is greater than 1, one bin
  is held back for the mask: entries equal to `self.mask_value` are mapped to
  0, and every other entry is hashed into `self.num_bins - 1` buckets and then
  shifted up by one. Otherwise the values are hashed directly into
  `self.num_bins` buckets.

  Args:
    values: A non-sparse tensor of values to hash. Integer-dtype values are
      converted to strings before hashing.

  Returns:
    The output of the hash-bucket op, shifted and masked as described above
    when a mask bin is reserved.
  """
  hash_fn = self._get_string_to_hash_bucket_fn()

  # With a single bin there is no room to reserve one for the mask, so the
  # mask is only applied when more than one bin is configured.
  reserve_mask_bin = self.mask_value is not None and self.num_bins > 1
  if reserve_mask_bin:
    bins_for_hashing = self.num_bins - 1
    # Compare against the raw values, before any integer->string conversion.
    is_masked = math_ops.equal(values, self.mask_value)
  else:
    bins_for_hashing = self.num_bins

  # The hash-bucket op is fed strings, so integer inputs are stringified.
  if values.dtype.is_integer:
    values = string_ops.as_string(values)
  hashed = hash_fn(values, bins_for_hashing, name='hash')

  if not reserve_mask_bin:
    return hashed

  # Shift every hashed index up by one to free index 0, then force the masked
  # positions to 0.
  shifted = math_ops.add(hashed, array_ops.ones_like(hashed))
  return array_ops.where(is_masked, array_ops.zeros_like(shifted), shifted)
def get_config(self):
  """Returns the layer config, including this layer's constructor arguments.

  Returns:
    A dict holding the base layer's config entries plus `num_bins`, `salt`
    and `mask_value`; this layer's entries take precedence on a key clash.
  """
  own_config = dict(
      num_bins=self.num_bins,
      salt=self.salt,
      mask_value=self.mask_value)
  # Start from a copy of the parent's config so it is not mutated in place.
  merged = dict(super(Hashing, self).get_config())
  merged.update(own_config)
  return merged
def test_hash_dense_multi_inputs_mask_value_farmhash(self):
  """Crossing several dense inputs is rejected when `mask_value` is set."""
  layer = hashing.Hashing(num_bins=3, mask_value='omar')
  names = np.asarray(
      [['omar'], ['stringer'], ['marlo'], ['wire'], ['skywalker']])
  letters = np.asarray([[letter] for letter in 'ABCDE'])
  # A list of inputs takes the cross-hashing path, which does not support
  # masking yet and must raise instead of silently ignoring the mask.
  with self.assertRaisesRegex(ValueError, 'not supported yet'):
    layer([names, letters])
def test_hash_ragged_input_mask_value(self):
  """Ragged inputs keep bin 0 for `mask_value` and shift the other bins."""
  empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
  omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
  ragged_names = ragged_factory_ops.constant(
      [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
      dtype=dtypes.string)
  hashed_with_empty_mask = empty_mask_layer(ragged_names)
  hashed_with_omar_mask = omar_mask_layer(ragged_names)
  # No entry equals '', so every output is one more than in
  # test_hash_ragged_string_input_farmhash (the zeroth bin is reserved for
  # masks).
  self.assertAllClose([[1, 1, 2, 1], [2, 1, 1]], hashed_with_empty_mask)
  # With 'omar' as the mask value, only the 'omar' entry maps to 0.
  self.assertAllClose([[0, 1, 2, 1], [2, 1, 1]], hashed_with_omar_mask)