int support for Hashing layer.
PiperOrigin-RevId: 294828742 Change-Id: I4a6574fe9f84874dfb4e25ffa3d2572c779c1913
This commit is contained in:
parent
39cfd72c06
commit
68b23d5123
@ -93,7 +93,15 @@ class Hashing(Layer):
|
||||
self._supports_ragged_inputs = True
|
||||
|
||||
def call(self, inputs):
|
||||
# TODO(tanzheny): Add int support.
|
||||
# Converts integer inputs to string.
|
||||
if inputs.dtype.is_integer:
|
||||
if isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
inputs = sparse_tensor.SparseTensor(
|
||||
indices=inputs.indices,
|
||||
values=string_ops.as_string(inputs.values),
|
||||
dense_shape=inputs.dense_shape)
|
||||
else:
|
||||
inputs = string_ops.as_string(inputs)
|
||||
str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
|
||||
if ragged_tensor.is_ragged(inputs):
|
||||
return ragged_functional_ops.map_flat_values(
|
||||
|
@ -51,6 +51,13 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [0], [0]], output)
|
||||
|
||||
def test_hash_dense_int_input_farmhash(self):
  """Hashes a dense 5x1 integer array with a 3-bin layer built without salt."""
  int_column = np.asarray([[0], [1], [2], [3], [4]])
  expected_bins = [[1], [0], [1], [0], [2]]
  hashed = hashing.Hashing(num_bins=3)(int_column)
  # The bucket ids below are expected to be the same on every platform.
  self.assertAllClose(expected_bins, hashed)
|
||||
|
||||
def test_hash_dense_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
@ -65,6 +72,13 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# Note the result is different from (133, 137).
|
||||
self.assertAllClose([[1], [0], [1], [0], [1]], output_2)
|
||||
|
||||
def test_hash_dense_int_input_siphash(self):
  """Hashes a dense 5x1 integer array with a 3-bin layer salted [133, 137]."""
  int_column = np.asarray([[0], [1], [2], [3], [4]])
  expected_bins = [[1], [1], [2], [0], [1]]
  hashed = hashing.Hashing(num_bins=3, salt=[133, 137])(int_column)
  # The bucket ids below are expected to be the same on every platform.
  self.assertAllClose(expected_bins, hashed)
|
||||
|
||||
def test_hash_sparse_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
|
||||
@ -76,6 +90,15 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(indices, output.indices)
|
||||
self.assertAllClose([0, 0, 1, 0, 0], output.values)
|
||||
|
||||
def test_hash_sparse_int_input_farmhash(self):
  """Hashes integer SparseTensor values with a 3-bin layer built without salt."""
  coords = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
  sparse_ints = sparse_tensor.SparseTensor(
      indices=coords, values=[0, 1, 2, 3, 4], dense_shape=[3, 2])
  hashed = hashing.Hashing(num_bins=3)(sparse_ints)
  # The output must keep the input's indices; only the values are hashed.
  self.assertAllClose(coords, hashed.indices)
  self.assertAllClose([1, 0, 1, 0, 2], hashed.values)
|
||||
|
||||
def test_hash_sparse_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
|
||||
@ -93,6 +116,15 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# The result should be same with test_hash_dense_input_siphash.
|
||||
self.assertAllClose([1, 0, 1, 0, 1], output.values)
|
||||
|
||||
def test_hash_sparse_int_input_siphash(self):
  """Hashes integer SparseTensor values with a 3-bin layer salted [133, 137]."""
  coords = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
  sparse_ints = sparse_tensor.SparseTensor(
      indices=coords, values=[0, 1, 2, 3, 4], dense_shape=[3, 2])
  hashed = hashing.Hashing(num_bins=3, salt=[133, 137])(sparse_ints)
  # The output must keep the input's indices; only the values are hashed.
  self.assertAllClose(coords, hashed.indices)
  self.assertAllClose([1, 1, 2, 0, 1], hashed.values)
|
||||
|
||||
def test_hash_ragged_string_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_data = ragged_factory_ops.constant(
|
||||
@ -108,6 +140,20 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_hash_ragged_int_input_farmhash(self):
  """Hashes an int64 ragged input directly and through a functional Model."""
  hasher = hashing.Hashing(num_bins=3)
  ragged_ints = ragged_factory_ops.constant(
      [[0, 1, 3, 4], [2, 1, 0]], dtype=dtypes.int64)
  eager_result = hasher(ragged_ints)
  # Each integer lands in the same bucket as in
  # test_hash_sparse_int_input_farmhash (0->1, 1->0, 2->1, 3->0, 4->2).
  self.assertAllEqual([[1, 0, 0, 2], [1, 0, 1]], eager_result)

  # The same layer wired into a Model must reproduce the direct-call result.
  symbolic_in = input_layer.Input(
      shape=(None,), ragged=True, dtype=dtypes.int64)
  symbolic_out = hasher(symbolic_in)
  wrapped = training.Model(inputs=symbolic_in, outputs=symbolic_out)
  self.assertAllClose(eager_result, wrapped.predict(ragged_ints))
|
||||
|
||||
def test_hash_ragged_string_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
inp_data = ragged_factory_ops.constant(
|
||||
@ -132,6 +178,20 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_hash_ragged_int_input_siphash(self):
  """Hashes an int64 ragged input with salt, directly and through a Model."""
  hasher = hashing.Hashing(num_bins=3, salt=[133, 137])
  ragged_ints = ragged_factory_ops.constant(
      [[0, 1, 3, 4], [2, 1, 0]], dtype=dtypes.int64)
  eager_result = hasher(ragged_ints)
  # Each integer lands in the same bucket as in
  # test_hash_sparse_int_input_siphash (0->1, 1->1, 2->2, 3->0, 4->1).
  self.assertAllEqual([[1, 1, 0, 1], [2, 1, 1]], eager_result)

  # The same layer wired into a Model must reproduce the direct-call result.
  symbolic_in = input_layer.Input(
      shape=(None,), ragged=True, dtype=dtypes.int64)
  symbolic_out = hasher(symbolic_in)
  wrapped = training.Model(inputs=symbolic_in, outputs=symbolic_out)
  self.assertAllClose(eager_result, wrapped.predict(ragged_ints))
|
||||
|
||||
def test_invalid_inputs(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
|
||||
_ = hashing.Hashing(num_bins=None)
|
||||
|
Loading…
Reference in New Issue
Block a user