Support native Python lists as input to the Hashing layer.

PiperOrigin-RevId: 315338165
Change-Id: I9d7859f63ebf748b6745b7d2c151ba1c8945a3c1
This commit is contained in:
Zhenyu Tan 2020-06-08 13:14:20 -07:00 committed by TensorFlower Gardener
parent 10ff91a6dc
commit 854e0bfdb2
4 changed files with 63 additions and 5 deletions

View File

@ -142,8 +142,8 @@ class CategoryCrossing(Layer):
def _preprocess_input(self, inp):
# Normalize a single input: Python lists/tuples and NumPy arrays are turned
# into tensors with `ops.convert_to_tensor`.
if isinstance(inp, (list, tuple, np.ndarray)):
inp = ops.convert_to_tensor(inp)
# A rank-1 input gets a trailing axis appended via `expand_dims(axis=-1)`.
# NOTE(review): the rank check shows up twice in this view; presumably these
# are the removed (nested under the isinstance branch) and added (top-level)
# sides of the diff hunk rendered together without +/- markers -- confirm
# against the real file before relying on either placement.
if inp.shape.rank == 1:
inp = array_ops.expand_dims(inp, axis=-1)
if inp.shape.rank == 1:
inp = array_ops.expand_dims(inp, axis=-1)
return inp
def call(self, inputs):

View File

@ -191,6 +191,11 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
output = layer([inputs_0, inputs_1])
self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output)
inputs_0 = np.asarray([1, 2])
inputs_1 = np.asarray([1, 3])
output = layer([inputs_0, inputs_1])
self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output)
def test_crossing_dense_inputs_depth_int(self):
layer = category_crossing.CategoryCrossing(depth=1)
inputs_0 = constant_op.constant([['a'], ['b'], ['c']])

View File

@ -19,11 +19,14 @@ from __future__ import division
from __future__ import print_function
import functools
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import sparse_ops
@ -58,7 +61,7 @@ class Hashing(Layer):
Example (FarmHash64):
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
>>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]
>>> layer(inp)
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[1],
@ -68,11 +71,24 @@ class Hashing(Layer):
[2]])>
Example (FarmHash64) with a list of inputs:
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
>>> inp_1 = [['A'], ['B'], ['C'], ['D'], ['E']]
>>> inp_2 = np.asarray([[5], [4], [3], [2], [1]])
>>> layer([inp_1, inp_2])
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[1],
[1],
[0],
[2],
[0]])>
Example (SipHash64):
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
... salt=[133, 137])
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
>>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]
>>> layer(inp)
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[1],
@ -85,7 +101,7 @@ class Hashing(Layer):
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
... salt=133)
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
>>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]
>>> layer(inp)
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[0],
@ -134,7 +150,23 @@ class Hashing(Layer):
else:
self.salt = _DEFAULT_SALT_KEY
def _preprocess_single_input(self, inp):
  """Return `inp` as a tensor when it is a list, tuple or NumPy array.

  Any other kind of value is handed back unchanged.

  Args:
    inp: The value to normalize.

  Returns:
    The result of `ops.convert_to_tensor(inp)` for list/tuple/ndarray input,
    otherwise `inp` itself.
  """
  if not isinstance(inp, (list, tuple, np.ndarray)):
    return inp
  return ops.convert_to_tensor(inp)
def _preprocess_inputs(self, inputs):
  """Normalize the layer inputs, preserving multi-input structure.

  A tuple/list containing at least one tensor or NumPy array is treated as a
  collection of separate inputs; each element is passed through
  `_preprocess_single_input` on its own. Anything else -- including a tuple or
  list made up only of plain Python values -- is treated as ONE input and
  passed through `_preprocess_single_input` as a whole.

  Args:
    inputs: A single input, or a tuple/list of inputs.

  Returns:
    A list of preprocessed inputs when `inputs` is treated as a collection,
    otherwise the single preprocessed input.
  """
  # A generator (rather than a list) lets `any` stop at the first element that
  # is a tensor or ndarray instead of inspecting every element up front.
  if isinstance(inputs, (tuple, list)) and any(
      tensor_util.is_tensor(inp) or isinstance(inp, np.ndarray)
      for inp in inputs):
    return [self._preprocess_single_input(inp) for inp in inputs]
  return self._preprocess_single_input(inputs)
def call(self, inputs):
inputs = self._preprocess_inputs(inputs)
if isinstance(inputs, (tuple, list)):
return self._process_input_list(inputs)
else:

View File

@ -60,6 +60,27 @@ class HashingTest(keras_parameterized.TestCase):
# Assert equal for hashed output that should be true on all platforms.
self.assertAllClose([[0], [0], [1], [1], [0]], output)
def test_hash_dense_list_input_farmhash(self):
  """Hashing accepts native Python lists, both rank-2 and rank-1."""
  layer = hashing.Hashing(num_bins=2)
  names = ['omar', 'stringer', 'marlo', 'wire', 'skywalker']
  expected_bins = [0, 0, 1, 0, 0]

  # Rank-2 input: one single-element row per string.
  hashed = layer([[name] for name in names])
  # Assert equal for hashed output that should be true on all platforms.
  self.assertAllClose([[b] for b in expected_bins], hashed)

  # Rank-1 input: the flat list of strings.
  hashed = layer(names)
  # Assert equal for hashed output that should be true on all platforms.
  self.assertAllClose(expected_bins, hashed)
def test_hash_dense_list_inputs_mixed_int_string_farmhash(self):
  """Hashing accepts a list holding a string array and an int64 array."""
  layer = hashing.Hashing(num_bins=2)
  string_column = np.asarray(
      [['omar'], ['stringer'], ['marlo'], ['wire'], ['skywalker']])
  int_column = np.asarray([[1], [2], [3], [4], [5]], dtype=np.int64)
  hashed = layer([string_column, int_column])
  # Assert equal for hashed output that should be true on all platforms.
  self.assertAllClose([[0], [1], [1], [1], [0]], hashed)
def test_hash_dense_int_input_farmhash(self):
layer = hashing.Hashing(num_bins=3)
inp = np.asarray([[0], [1], [2], [3], [4]])