Let CategoryCrossing support native python list

PiperOrigin-RevId: 314857565
Change-Id: I118bc5212a1514c4f2614b73358872135444e4dd
This commit is contained in:
Zhenyu Tan 2020-06-04 20:48:51 -07:00 committed by TensorFlower Gardener
parent 2e311c6a9c
commit 0418b68c73
2 changed files with 26 additions and 4 deletions

View File

@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
@ -40,8 +42,8 @@ class CategoryCrossing(Layer):
output (similar to Cartesian product). The output dtype is string.
Usage:
>>> inp_1 = tf.constant([['a'], ['b'], ['c']])
>>> inp_2 = tf.constant([['d'], ['e'], ['f']])
>>> inp_1 = ['a', 'b', 'c']
>>> inp_2 = ['d', 'e', 'f']
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
>>> layer([inp_1, inp_2])
<tf.Tensor: shape=(3, 1), dtype=string, numpy=
@ -50,8 +52,8 @@ class CategoryCrossing(Layer):
[b'c_X_f']], dtype=object)>
>>> inp_1 = tf.constant([['a'], ['b'], ['c']])
>>> inp_2 = tf.constant([['d'], ['e'], ['f']])
>>> inp_1 = ['a', 'b', 'c']
>>> inp_2 = ['d', 'e', 'f']
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
... separator='-')
>>> layer([inp_1, inp_2])
@ -137,7 +139,15 @@ class CategoryCrossing(Layer):
return sparse_ops.sparse_tensor_to_dense(
sparse_ops.sparse_cross(partial_inputs, separator=self.separator))
def _preprocess_input(self, inp):
if isinstance(inp, (list, tuple, np.ndarray)):
inp = ops.convert_to_tensor(inp)
if inp.shape.rank == 1:
inp = array_ops.expand_dims(inp, axis=-1)
return inp
def call(self, inputs):
inputs = [self._preprocess_input(inp) for inp in inputs]
depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
ragged_out = sparse_out = False
if any(ragged_tensor.is_ragged(inp) for inp in inputs):

View File

@ -179,6 +179,18 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
output = layer([inputs_0, inputs_1])
self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output)
def test_crossing_with_list_inputs(self):
layer = category_crossing.CategoryCrossing()
inputs_0 = [[1, 2]]
inputs_1 = [[1, 3]]
output = layer([inputs_0, inputs_1])
self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output)
inputs_0 = [1, 2]
inputs_1 = [1, 3]
output = layer([inputs_0, inputs_1])
self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output)
def test_crossing_dense_inputs_depth_int(self):
layer = category_crossing.CategoryCrossing(depth=1)
inputs_0 = constant_op.constant([['a'], ['b'], ['c']])