From 0418b68c730d63a075c6b3f754d15be8af8eeb5a Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Thu, 4 Jun 2020 20:48:51 -0700 Subject: [PATCH] Let CategoryCrossing support native python list PiperOrigin-RevId: 314857565 Change-Id: I118bc5212a1514c4f2614b73358872135444e4dd --- .../layers/preprocessing/category_crossing.py | 18 ++++++++++++++---- .../preprocessing/category_crossing_test.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/layers/preprocessing/category_crossing.py b/tensorflow/python/keras/layers/preprocessing/category_crossing.py index 84e5332bea5..fa0237595ac 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_crossing.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing.py @@ -19,8 +19,10 @@ from __future__ import division from __future__ import print_function import itertools +import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec @@ -40,8 +42,8 @@ class CategoryCrossing(Layer): output (similar to Cartesian product). The output dtype is string. Usage: - >>> inp_1 = tf.constant([['a'], ['b'], ['c']]) - >>> inp_2 = tf.constant([['d'], ['e'], ['f']]) + >>> inp_1 = ['a', 'b', 'c'] + >>> inp_2 = ['d', 'e', 'f'] >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing() >>> layer([inp_1, inp_2]) - >>> inp_1 = tf.constant([['a'], ['b'], ['c']]) - >>> inp_2 = tf.constant([['d'], ['e'], ['f']]) + >>> inp_1 = ['a', 'b', 'c'] + >>> inp_2 = ['d', 'e', 'f'] >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing( ... separator='-') >>> layer([inp_1, inp_2]) @@ -137,7 +139,15 @@ class CategoryCrossing(Layer): return sparse_ops.sparse_tensor_to_dense( sparse_ops.sparse_cross(partial_inputs, separator=self.separator)) + def _preprocess_input(self, inp): + if isinstance(inp, (list, tuple, np.ndarray)): + inp = ops.convert_to_tensor(inp) + if inp.shape.rank == 1: + inp = array_ops.expand_dims(inp, axis=-1) + return inp + def call(self, inputs): + inputs = [self._preprocess_input(inp) for inp in inputs] depth_tuple = self._depth_tuple if self.depth else (len(inputs),) ragged_out = sparse_out = False if any(ragged_tensor.is_ragged(inp) for inp in inputs): diff --git a/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py index f076c9ea865..83e78c4dd46 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py @@ -179,6 +179,18 @@ class CategoryCrossingTest(keras_parameterized.TestCase): output = layer([inputs_0, inputs_1]) self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output) + def test_crossing_with_list_inputs(self): + layer = category_crossing.CategoryCrossing() + inputs_0 = [[1, 2]] + inputs_1 = [[1, 3]] + output = layer([inputs_0, inputs_1]) + self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output) + + inputs_0 = [1, 2] + inputs_1 = [1, 3] + output = layer([inputs_0, inputs_1]) + self.assertAllEqual([[b'1_X_1'], [b'2_X_3']], output) + def test_crossing_dense_inputs_depth_int(self): layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = constant_op.constant([['a'], ['b'], ['c']])