Fixup output shape for IntegerLookup/StringLookup layers

This makes the following fixes for BINARY and COUNT output
 - Fixes compute_output_shape and compute_output_signature
 - Properly propogates batch shape for dense inputs
 - Adds test coverage

PiperOrigin-RevId: 355071871
Change-Id: I7820763100b643b8cd12908caf416aae1c4a1f14
This commit is contained in:
Matt Watson 2021-02-01 18:33:46 -08:00 committed by TensorFlower Gardener
parent 229cbce4ca
commit 5031630885
4 changed files with 79 additions and 11 deletions

View File

@ -160,6 +160,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:util",
"//tensorflow/python/keras:backend",
"//tensorflow/python/keras/engine",
"//third_party/py/numpy",
],

View File

@ -534,5 +534,6 @@ def dense_bincount(inputs, out_depth, binary_output, count_weights=None):
dtype=K.floatx(),
axis=-1,
binary_output=binary_output)
result.set_shape(tensor_shape.TensorShape((None, out_depth)))
batch_size = inputs.shape.as_list()[0]
result.set_shape(tensor_shape.TensorShape((batch_size, out_depth)))
return result

View File

@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import table_utils
@ -160,22 +161,20 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
super(IndexLookup, self).__init__(
combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
self._output_dtype = dtypes.int64
# We need to save the key dtype so that we know if we're expecting int64
# keys. If we are, we will cast int32 inputs to int64 as well.
if invert:
self._key_dtype = self._output_dtype
value_dtype = self.dtype
self._key_dtype = dtypes.int64
self._value_dtype = self.dtype
oov_value = self.oov_token
else:
self._key_dtype = self.dtype
value_dtype = self._output_dtype
self._value_dtype = dtypes.int64
oov_value = self._oov_value
self._table = lookup_ops.MutableHashTable(
key_dtype=self._key_dtype,
value_dtype=value_dtype,
value_dtype=self._value_dtype,
default_value=oov_value,
name=(self._name + "_index_table"))
tracked_table = self._add_trackable(self._table, trainable=False)
@ -201,11 +200,14 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
self.set_vocabulary(vocabulary)
def compute_output_shape(self, input_shape):
if self.output_mode != INT:
return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
return input_shape
def compute_output_signature(self, input_spec):
output_shape = self.compute_output_shape(input_spec.shape.as_list())
output_dtype = self.dtype if self.invert else self._output_dtype
output_dtype = self._value_dtype if self.output_mode == INT else K.floatx()
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
def adapt(self, data, reset_state=True):

View File

@ -618,8 +618,8 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_output_shape(self):
input_data = keras.Input(shape=(4,), dtype=dtypes.string)
def test_int_output_shape(self):
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
layer = get_layer_class()(
max_tokens=2,
num_oov_indices=1,
@ -627,7 +627,7 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
oov_token="[OOV]",
dtype=dtypes.string)
int_data = layer(input_data)
self.assertAllEqual(int_data.shape[1:], input_data.shape[1:])
self.assertAllEqual(int_data.shape.as_list(), [16, 4])
def test_int_output_no_reserved_zero(self):
vocab_data = ["earth", "wind", "and", "fire"]
@ -667,6 +667,70 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_binary_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
expected_output = [[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]]
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
layer = get_layer_class()(
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.BINARY,
dtype=dtypes.string)
layer.set_vocabulary(vocab_data)
binary_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=binary_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_binary_output_shape(self):
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
layer = get_layer_class()(
max_tokens=2,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.BINARY,
dtype=dtypes.string)
binary_data = layer(input_data)
self.assertAllEqual(binary_data.shape.as_list(), [16, 2])
def test_count_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "wind"],
["fire", "fire", "fire", "michigan"]])
expected_output = [[0, 0, 1, 2, 1, 0], [0, 1, 0, 0, 0, 3]]
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
layer = get_layer_class()(
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.COUNT,
dtype=dtypes.string)
layer.set_vocabulary(vocab_data)
count_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=count_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_count_output_shape(self):
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
layer = get_layer_class()(
max_tokens=2,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.COUNT,
dtype=dtypes.string)
count_data = layer(input_data)
self.assertAllEqual(count_data.shape.as_list(), [16, 2])
@keras_parameterized.run_all_keras_modes
class IndexLookupVocabularyTest(keras_parameterized.TestCase,