Fixup output shape for IntegerLookup/StringLookup layers
This makes the following fixes for BINARY and COUNT output:
- Fixes compute_output_shape and compute_output_signature
- Properly propagates batch shape for dense inputs
- Adds test coverage

PiperOrigin-RevId: 355071871
Change-Id: I7820763100b643b8cd12908caf416aae1c4a1f14
This commit is contained in:
parent
229cbce4ca
commit
5031630885
@ -160,6 +160,7 @@ py_library(
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
|
||||
@ -534,5 +534,6 @@ def dense_bincount(inputs, out_depth, binary_output, count_weights=None):
|
||||
dtype=K.floatx(),
|
||||
axis=-1,
|
||||
binary_output=binary_output)
|
||||
result.set_shape(tensor_shape.TensorShape((None, out_depth)))
|
||||
batch_size = inputs.shape.as_list()[0]
|
||||
result.set_shape(tensor_shape.TensorShape((batch_size, out_depth)))
|
||||
return result
|
||||
|
||||
@ -27,6 +27,7 @@ import numpy as np
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import base_preprocessing_layer
|
||||
from tensorflow.python.keras.layers.preprocessing import category_encoding
|
||||
from tensorflow.python.keras.layers.preprocessing import table_utils
|
||||
@ -160,22 +161,20 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
super(IndexLookup, self).__init__(
|
||||
combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
|
||||
|
||||
self._output_dtype = dtypes.int64
|
||||
|
||||
# We need to save the key dtype so that we know if we're expecting int64
|
||||
# keys. If we are, we will cast int32 inputs to int64 as well.
|
||||
if invert:
|
||||
self._key_dtype = self._output_dtype
|
||||
value_dtype = self.dtype
|
||||
self._key_dtype = dtypes.int64
|
||||
self._value_dtype = self.dtype
|
||||
oov_value = self.oov_token
|
||||
else:
|
||||
self._key_dtype = self.dtype
|
||||
value_dtype = self._output_dtype
|
||||
self._value_dtype = dtypes.int64
|
||||
oov_value = self._oov_value
|
||||
|
||||
self._table = lookup_ops.MutableHashTable(
|
||||
key_dtype=self._key_dtype,
|
||||
value_dtype=value_dtype,
|
||||
value_dtype=self._value_dtype,
|
||||
default_value=oov_value,
|
||||
name=(self._name + "_index_table"))
|
||||
tracked_table = self._add_trackable(self._table, trainable=False)
|
||||
@ -201,11 +200,14 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
self.set_vocabulary(vocabulary)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
if self.output_mode != INT:
|
||||
return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
|
||||
|
||||
return input_shape
|
||||
|
||||
def compute_output_signature(self, input_spec):
|
||||
output_shape = self.compute_output_shape(input_spec.shape.as_list())
|
||||
output_dtype = self.dtype if self.invert else self._output_dtype
|
||||
output_dtype = self._value_dtype if self.output_mode == INT else K.floatx()
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
||||
|
||||
def adapt(self, data, reset_state=True):
|
||||
|
||||
@ -618,8 +618,8 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_output_shape(self):
|
||||
input_data = keras.Input(shape=(4,), dtype=dtypes.string)
|
||||
def test_int_output_shape(self):
|
||||
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=2,
|
||||
num_oov_indices=1,
|
||||
@ -627,7 +627,7 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
oov_token="[OOV]",
|
||||
dtype=dtypes.string)
|
||||
int_data = layer(input_data)
|
||||
self.assertAllEqual(int_data.shape[1:], input_data.shape[1:])
|
||||
self.assertAllEqual(int_data.shape.as_list(), [16, 4])
|
||||
|
||||
def test_int_output_no_reserved_zero(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
@ -667,6 +667,70 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_binary_output(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", "michigan"]])
|
||||
expected_output = [[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]]
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="[OOV]",
|
||||
output_mode=index_lookup.BINARY,
|
||||
dtype=dtypes.string)
|
||||
layer.set_vocabulary(vocab_data)
|
||||
binary_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=binary_data)
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_binary_output_shape(self):
|
||||
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=2,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="[OOV]",
|
||||
output_mode=index_lookup.BINARY,
|
||||
dtype=dtypes.string)
|
||||
binary_data = layer(input_data)
|
||||
self.assertAllEqual(binary_data.shape.as_list(), [16, 2])
|
||||
|
||||
def test_count_output(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "wind"],
|
||||
["fire", "fire", "fire", "michigan"]])
|
||||
expected_output = [[0, 0, 1, 2, 1, 0], [0, 1, 0, 0, 0, 3]]
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=None,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="[OOV]",
|
||||
output_mode=index_lookup.COUNT,
|
||||
dtype=dtypes.string)
|
||||
layer.set_vocabulary(vocab_data)
|
||||
count_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=count_data)
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_count_output_shape(self):
|
||||
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(
|
||||
max_tokens=2,
|
||||
num_oov_indices=1,
|
||||
mask_token="",
|
||||
oov_token="[OOV]",
|
||||
output_mode=index_lookup.COUNT,
|
||||
dtype=dtypes.string)
|
||||
count_data = layer(input_data)
|
||||
self.assertAllEqual(count_data.shape.as_list(), [16, 2])
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user