Let CategoryCrossing support native python list
PiperOrigin-RevId: 314857565 Change-Id: I118bc5212a1514c4f2614b73358872135444e4dd
This commit is contained in:
parent
2e311c6a9c
commit
0418b68c73
@ -19,8 +19,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -40,8 +42,8 @@ class CategoryCrossing(Layer):
|
||||
output (similar to Cartesian product). The output dtype is string.
|
||||
|
||||
Usage:
|
||||
>>> inp_1 = tf.constant([['a'], ['b'], ['c']])
|
||||
>>> inp_2 = tf.constant([['d'], ['e'], ['f']])
|
||||
>>> inp_1 = ['a', 'b', 'c']
|
||||
>>> inp_2 = ['d', 'e', 'f']
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
|
||||
>>> layer([inp_1, inp_2])
|
||||
<tf.Tensor: shape=(3, 1), dtype=string, numpy=
|
||||
@ -50,8 +52,8 @@ class CategoryCrossing(Layer):
|
||||
[b'c_X_f']], dtype=object)>
|
||||
|
||||
|
||||
>>> inp_1 = tf.constant([['a'], ['b'], ['c']])
|
||||
>>> inp_2 = tf.constant([['d'], ['e'], ['f']])
|
||||
>>> inp_1 = ['a', 'b', 'c']
|
||||
>>> inp_2 = ['d', 'e', 'f']
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
|
||||
... separator='-')
|
||||
>>> layer([inp_1, inp_2])
|
||||
@ -137,7 +139,15 @@ class CategoryCrossing(Layer):
|
||||
return sparse_ops.sparse_tensor_to_dense(
|
||||
sparse_ops.sparse_cross(partial_inputs, separator=self.separator))
|
||||
|
||||
def _preprocess_input(self, inp):
|
||||
if isinstance(inp, (list, tuple, np.ndarray)):
|
||||
inp = ops.convert_to_tensor(inp)
|
||||
if inp.shape.rank == 1:
|
||||
inp = array_ops.expand_dims(inp, axis=-1)
|
||||
return inp
|
||||
|
||||
def call(self, inputs):
|
||||
inputs = [self._preprocess_input(inp) for inp in inputs]
|
||||
depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
|
||||
ragged_out = sparse_out = False
|
||||
if any(ragged_tensor.is_ragged(inp) for inp in inputs):
|
||||
|
@ -179,6 +179,18 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output)
|
||||
|
||||
def test_crossing_with_list_inputs(self):
|
||||
layer = category_crossing.CategoryCrossing()
|
||||
inputs_0 = [[1, 2]]
|
||||
inputs_1 = [[1, 3]]
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output)
|
||||
|
||||
inputs_0 = [1, 2]
|
||||
inputs_1 = [1, 3]
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output)
|
||||
|
||||
def test_crossing_dense_inputs_depth_int(self):
|
||||
layer = category_crossing.CategoryCrossing(depth=1)
|
||||
inputs_0 = constant_op.constant([['a'], ['b'], ['c']])
|
||||
|
Loading…
Reference in New Issue
Block a user