Fixup output shape for IntegerLookup/StringLookup layers
This makes the following fixes for BINARY and COUNT output:
- Fixes compute_output_shape and compute_output_signature
- Properly propagates batch shape for dense inputs
- Adds test coverage

PiperOrigin-RevId: 355071871
Change-Id: I7820763100b643b8cd12908caf416aae1c4a1f14
This commit is contained in:
parent
229cbce4ca
commit
5031630885
@ -160,6 +160,7 @@ py_library(
|
|||||||
"//tensorflow/python:tensor_shape",
|
"//tensorflow/python:tensor_shape",
|
||||||
"//tensorflow/python:tensor_spec",
|
"//tensorflow/python:tensor_spec",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
|
"//tensorflow/python/keras:backend",
|
||||||
"//tensorflow/python/keras/engine",
|
"//tensorflow/python/keras/engine",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
|
|||||||
@ -534,5 +534,6 @@ def dense_bincount(inputs, out_depth, binary_output, count_weights=None):
|
|||||||
dtype=K.floatx(),
|
dtype=K.floatx(),
|
||||||
axis=-1,
|
axis=-1,
|
||||||
binary_output=binary_output)
|
binary_output=binary_output)
|
||||||
result.set_shape(tensor_shape.TensorShape((None, out_depth)))
|
batch_size = inputs.shape.as_list()[0]
|
||||||
|
result.set_shape(tensor_shape.TensorShape((batch_size, out_depth)))
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import numpy as np
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine import base_preprocessing_layer
|
from tensorflow.python.keras.engine import base_preprocessing_layer
|
||||||
from tensorflow.python.keras.layers.preprocessing import category_encoding
|
from tensorflow.python.keras.layers.preprocessing import category_encoding
|
||||||
from tensorflow.python.keras.layers.preprocessing import table_utils
|
from tensorflow.python.keras.layers.preprocessing import table_utils
|
||||||
@ -160,22 +161,20 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
super(IndexLookup, self).__init__(
|
super(IndexLookup, self).__init__(
|
||||||
combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
|
combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
|
||||||
|
|
||||||
self._output_dtype = dtypes.int64
|
|
||||||
|
|
||||||
# We need to save the key dtype so that we know if we're expecting int64
|
# We need to save the key dtype so that we know if we're expecting int64
|
||||||
# keys. If we are, we will cast int32 inputs to int64 as well.
|
# keys. If we are, we will cast int32 inputs to int64 as well.
|
||||||
if invert:
|
if invert:
|
||||||
self._key_dtype = self._output_dtype
|
self._key_dtype = dtypes.int64
|
||||||
value_dtype = self.dtype
|
self._value_dtype = self.dtype
|
||||||
oov_value = self.oov_token
|
oov_value = self.oov_token
|
||||||
else:
|
else:
|
||||||
self._key_dtype = self.dtype
|
self._key_dtype = self.dtype
|
||||||
value_dtype = self._output_dtype
|
self._value_dtype = dtypes.int64
|
||||||
oov_value = self._oov_value
|
oov_value = self._oov_value
|
||||||
|
|
||||||
self._table = lookup_ops.MutableHashTable(
|
self._table = lookup_ops.MutableHashTable(
|
||||||
key_dtype=self._key_dtype,
|
key_dtype=self._key_dtype,
|
||||||
value_dtype=value_dtype,
|
value_dtype=self._value_dtype,
|
||||||
default_value=oov_value,
|
default_value=oov_value,
|
||||||
name=(self._name + "_index_table"))
|
name=(self._name + "_index_table"))
|
||||||
tracked_table = self._add_trackable(self._table, trainable=False)
|
tracked_table = self._add_trackable(self._table, trainable=False)
|
||||||
@ -201,11 +200,14 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
self.set_vocabulary(vocabulary)
|
self.set_vocabulary(vocabulary)
|
||||||
|
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
|
if self.output_mode != INT:
|
||||||
|
return tensor_shape.TensorShape([input_shape[0], self.max_tokens])
|
||||||
|
|
||||||
return input_shape
|
return input_shape
|
||||||
|
|
||||||
def compute_output_signature(self, input_spec):
|
def compute_output_signature(self, input_spec):
|
||||||
output_shape = self.compute_output_shape(input_spec.shape.as_list())
|
output_shape = self.compute_output_shape(input_spec.shape.as_list())
|
||||||
output_dtype = self.dtype if self.invert else self._output_dtype
|
output_dtype = self._value_dtype if self.output_mode == INT else K.floatx()
|
||||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
||||||
|
|
||||||
def adapt(self, data, reset_state=True):
|
def adapt(self, data, reset_state=True):
|
||||||
|
|||||||
@ -618,8 +618,8 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
|||||||
output_dataset = model.predict(input_array)
|
output_dataset = model.predict(input_array)
|
||||||
self.assertAllEqual(expected_output, output_dataset)
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
def test_output_shape(self):
|
def test_int_output_shape(self):
|
||||||
input_data = keras.Input(shape=(4,), dtype=dtypes.string)
|
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
|
||||||
layer = get_layer_class()(
|
layer = get_layer_class()(
|
||||||
max_tokens=2,
|
max_tokens=2,
|
||||||
num_oov_indices=1,
|
num_oov_indices=1,
|
||||||
@ -627,7 +627,7 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
|||||||
oov_token="[OOV]",
|
oov_token="[OOV]",
|
||||||
dtype=dtypes.string)
|
dtype=dtypes.string)
|
||||||
int_data = layer(input_data)
|
int_data = layer(input_data)
|
||||||
self.assertAllEqual(int_data.shape[1:], input_data.shape[1:])
|
self.assertAllEqual(int_data.shape.as_list(), [16, 4])
|
||||||
|
|
||||||
def test_int_output_no_reserved_zero(self):
|
def test_int_output_no_reserved_zero(self):
|
||||||
vocab_data = ["earth", "wind", "and", "fire"]
|
vocab_data = ["earth", "wind", "and", "fire"]
|
||||||
@ -667,6 +667,70 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
|||||||
output_dataset = model.predict(input_array)
|
output_dataset = model.predict(input_array)
|
||||||
self.assertAllEqual(expected_output, output_dataset)
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
|
def test_binary_output(self):
|
||||||
|
vocab_data = ["earth", "wind", "and", "fire"]
|
||||||
|
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||||
|
["fire", "and", "earth", "michigan"]])
|
||||||
|
expected_output = [[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]]
|
||||||
|
|
||||||
|
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||||
|
layer = get_layer_class()(
|
||||||
|
max_tokens=None,
|
||||||
|
num_oov_indices=1,
|
||||||
|
mask_token="",
|
||||||
|
oov_token="[OOV]",
|
||||||
|
output_mode=index_lookup.BINARY,
|
||||||
|
dtype=dtypes.string)
|
||||||
|
layer.set_vocabulary(vocab_data)
|
||||||
|
binary_data = layer(input_data)
|
||||||
|
model = keras.Model(inputs=input_data, outputs=binary_data)
|
||||||
|
output_dataset = model.predict(input_array)
|
||||||
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
|
def test_binary_output_shape(self):
|
||||||
|
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
|
||||||
|
layer = get_layer_class()(
|
||||||
|
max_tokens=2,
|
||||||
|
num_oov_indices=1,
|
||||||
|
mask_token="",
|
||||||
|
oov_token="[OOV]",
|
||||||
|
output_mode=index_lookup.BINARY,
|
||||||
|
dtype=dtypes.string)
|
||||||
|
binary_data = layer(input_data)
|
||||||
|
self.assertAllEqual(binary_data.shape.as_list(), [16, 2])
|
||||||
|
|
||||||
|
def test_count_output(self):
|
||||||
|
vocab_data = ["earth", "wind", "and", "fire"]
|
||||||
|
input_array = np.array([["earth", "wind", "and", "wind"],
|
||||||
|
["fire", "fire", "fire", "michigan"]])
|
||||||
|
expected_output = [[0, 0, 1, 2, 1, 0], [0, 1, 0, 0, 0, 3]]
|
||||||
|
|
||||||
|
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||||
|
layer = get_layer_class()(
|
||||||
|
max_tokens=None,
|
||||||
|
num_oov_indices=1,
|
||||||
|
mask_token="",
|
||||||
|
oov_token="[OOV]",
|
||||||
|
output_mode=index_lookup.COUNT,
|
||||||
|
dtype=dtypes.string)
|
||||||
|
layer.set_vocabulary(vocab_data)
|
||||||
|
count_data = layer(input_data)
|
||||||
|
model = keras.Model(inputs=input_data, outputs=count_data)
|
||||||
|
output_dataset = model.predict(input_array)
|
||||||
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
|
def test_count_output_shape(self):
|
||||||
|
input_data = keras.Input(batch_size=16, shape=(4,), dtype=dtypes.string)
|
||||||
|
layer = get_layer_class()(
|
||||||
|
max_tokens=2,
|
||||||
|
num_oov_indices=1,
|
||||||
|
mask_token="",
|
||||||
|
oov_token="[OOV]",
|
||||||
|
output_mode=index_lookup.COUNT,
|
||||||
|
dtype=dtypes.string)
|
||||||
|
count_data = layer(input_data)
|
||||||
|
self.assertAllEqual(count_data.shape.as_list(), [16, 2])
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
|
class IndexLookupVocabularyTest(keras_parameterized.TestCase,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user