Support native python list for Hashing layer.
PiperOrigin-RevId: 315338165 Change-Id: I9d7859f63ebf748b6745b7d2c151ba1c8945a3c1
This commit is contained in:
parent
10ff91a6dc
commit
854e0bfdb2
tensorflow/python/keras/layers/preprocessing
@ -142,8 +142,8 @@ class CategoryCrossing(Layer):
|
||||
def _preprocess_input(self, inp):
|
||||
if isinstance(inp, (list, tuple, np.ndarray)):
|
||||
inp = ops.convert_to_tensor(inp)
|
||||
if inp.shape.rank == 1:
|
||||
inp = array_ops.expand_dims(inp, axis=-1)
|
||||
if inp.shape.rank == 1:
|
||||
inp = array_ops.expand_dims(inp, axis=-1)
|
||||
return inp
|
||||
|
||||
def call(self, inputs):
|
||||
|
@ -191,6 +191,11 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output)
|
||||
|
||||
inputs_0 = np.asarray([1, 2])
|
||||
inputs_1 = np.asarray([1, 3])
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output)
|
||||
|
||||
def test_crossing_dense_inputs_depth_int(self):
|
||||
layer = category_crossing.CategoryCrossing(depth=1)
|
||||
inputs_0 = constant_op.constant([['a'], ['b'], ['c']])
|
||||
|
@ -19,11 +19,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.ops import gen_sparse_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
@ -58,7 +61,7 @@ class Hashing(Layer):
|
||||
Example (FarmHash64):
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
|
||||
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
>>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[1],
|
||||
@ -68,11 +71,24 @@ class Hashing(Layer):
|
||||
[2]])>
|
||||
|
||||
|
||||
Example (FarmHash64) with list of inputs:
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
|
||||
>>> inp_1 = [['A'], ['B'], ['C'], ['D'], ['E']]
|
||||
>>> inp_2 = np.asarray([[5], [4], [3], [2], [1]])
|
||||
>>> layer([inp_1, inp_2])
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[1],
|
||||
[1],
|
||||
[0],
|
||||
[2],
|
||||
[0]])>
|
||||
|
||||
|
||||
Example (SipHash64):
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
|
||||
... salt=[133, 137])
|
||||
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
>>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[1],
|
||||
@ -85,7 +101,7 @@ class Hashing(Layer):
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
|
||||
... salt=133)
|
||||
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
>>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[0],
|
||||
@ -134,7 +150,23 @@ class Hashing(Layer):
|
||||
else:
|
||||
self.salt = _DEFAULT_SALT_KEY
|
||||
|
||||
def _preprocess_single_input(self, inp):
|
||||
if isinstance(inp, (list, tuple, np.ndarray)):
|
||||
inp = ops.convert_to_tensor(inp)
|
||||
return inp
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
# If any of them is tensor or ndarray, then treat as list
|
||||
if any([
|
||||
tensor_util.is_tensor(inp) or isinstance(inp, np.ndarray)
|
||||
for inp in inputs
|
||||
]):
|
||||
return [self._preprocess_single_input(inp) for inp in inputs]
|
||||
return self._preprocess_single_input(inputs)
|
||||
|
||||
def call(self, inputs):
|
||||
inputs = self._preprocess_inputs(inputs)
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
return self._process_input_list(inputs)
|
||||
else:
|
||||
|
@ -60,6 +60,27 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [1], [0]], output)
|
||||
|
||||
def test_hash_dense_list_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp = [['omar'], ['stringer'], ['marlo'], ['wire'], ['skywalker']]
|
||||
output = layer(inp)
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [0], [0]], output)
|
||||
|
||||
inp = ['omar', 'stringer', 'marlo', 'wire', 'skywalker']
|
||||
output = layer(inp)
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([0, 0, 1, 0, 0], output)
|
||||
|
||||
def test_hash_dense_list_inputs_mixed_int_string_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
['skywalker']])
|
||||
inp_2 = np.asarray([[1], [2], [3], [4], [5]]).astype(np.int64)
|
||||
output = layer([inp_1, inp_2])
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [1], [1], [1], [0]], output)
|
||||
|
||||
def test_hash_dense_int_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=3)
|
||||
inp = np.asarray([[0], [1], [2], [3], [4]])
|
||||
|
Loading…
Reference in New Issue
Block a user