Allow index_lookup to perform inverse lookups.
PiperOrigin-RevId: 293471194 Change-Id: Idd2d08e8663fb8ae69df3eb6923e1e8fa876d154
This commit is contained in:
parent
a3594d8126
commit
0f27efc0f8
@ -73,6 +73,12 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
mask_zero: If True, input values of 0 (for integers) and `""` (for strings)
|
||||
will be treated as masked values and assigned an output value of 0. If
|
||||
this option is set, `reserve_zero` must also be set. Defaults to False.
|
||||
|
||||
Call arguments:
|
||||
inputs: The data to look up. Can be a tf.Tensor or RaggedTensor.
|
||||
invert: Controls the lookup direction. If False, the layer will map strings
|
||||
to integers; if true, the layer will map integers to strings. Defaults
|
||||
to False.
|
||||
"""
|
||||
# TODO(momernick): Add an examples section to the docstring.
|
||||
|
||||
@ -141,12 +147,17 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
self._output_dtype = dtypes.int32
|
||||
else:
|
||||
self._output_dtype = dtypes.int64
|
||||
|
||||
self._table = lookup_ops.MutableHashTable(
|
||||
key_dtype=self.dtype,
|
||||
value_dtype=self._output_dtype,
|
||||
default_value=self._oov_value,
|
||||
name=(self._name + "_index_table"))
|
||||
tracked_table = self._add_trackable(self._table, trainable=False)
|
||||
# This is a workaround for summary() on this layer. Because the table is
|
||||
# not mutable during training, the effective number of parameters (and so
|
||||
# the weight shape) is 0; we add this as an attr so that the parameter
|
||||
# counting code in the Model object doesn't throw an attribute error.
|
||||
tracked_table.shape = tensor_shape.TensorShape((0,))
|
||||
|
||||
# This is a workaround for saving not working yet for MutableHashTables.
|
||||
# By replacing the existing function call by an explicit failure, we
|
||||
@ -154,15 +165,8 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
def fail(_):
|
||||
raise NotImplementedError(
|
||||
"Saving is not yet supported for IndexLookup layers.")
|
||||
|
||||
self._table._list_extra_dependencies_for_serialization = fail # pylint: disable=protected-access
|
||||
|
||||
tracked_table = self._add_trackable(self._table, trainable=False)
|
||||
# This is a workaround for summary() on this layer. Because the table is
|
||||
# not mutable during training, the effective number of parameters (and so
|
||||
# the weight shape) is 0; we add this as an attr so that the parameter
|
||||
# counting code in the Model object doesn't throw an attribute error.
|
||||
tracked_table.shape = tensor_shape.TensorShape((0,))
|
||||
self._inverse_table = None
|
||||
|
||||
def _get_table_data(self):
|
||||
keys, values = self._table.export()
|
||||
@ -174,6 +178,9 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
def _clear_table(self):
|
||||
keys, _ = self._table.export()
|
||||
self._table.remove(keys)
|
||||
if self._inverse_table:
|
||||
keys, _ = self._inverse_table.export()
|
||||
self._inverse_table.remove(keys)
|
||||
|
||||
def _insert_table_data(self, keys, values):
|
||||
if len(values) != len(keys):
|
||||
@ -181,6 +188,12 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
"Keys had size %s, values had size %s." %
|
||||
(len(keys), len(values)))
|
||||
self._table.insert(keys, values)
|
||||
if self._inverse_table:
|
||||
self._inverse_table.insert(values, keys)
|
||||
|
||||
def _initialize_inverse_table(self):
|
||||
keys, values = self._table.export()
|
||||
self._inverse_table.insert(values, keys)
|
||||
|
||||
def _to_numpy(self, preprocessed_data):
|
||||
"""Converts preprocessed inputs into numpy arrays."""
|
||||
@ -207,9 +220,12 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
def compute_output_signature(self, input_spec):
|
||||
def compute_output_signature(self, input_spec, invert=False):
|
||||
output_shape = self.compute_output_shape(input_spec.shape.as_list())
|
||||
output_dtype = dtypes.int64
|
||||
if invert:
|
||||
output_dtype = dtypes.string
|
||||
else:
|
||||
output_dtype = dtypes.int64
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
||||
|
||||
def adapt(self, data, reset_state=True):
|
||||
@ -243,7 +259,7 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
"max_tokens": self.max_tokens,
|
||||
"num_oov_tokens": self.num_oov_tokens,
|
||||
"reserve_zero": self.reserve_zero,
|
||||
"mask_zero": self.mask_zero
|
||||
"mask_zero": self.mask_zero,
|
||||
}
|
||||
base_config = super(IndexLookup, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
@ -301,14 +317,45 @@ class IndexLookup(CombinerPreprocessingLayer):
|
||||
raise RuntimeError("_set_state_variables() must be called after build().")
|
||||
self.set_vocabulary(updates[_VOCAB_NAME])
|
||||
|
||||
def call(self, inputs):
|
||||
def __call__(self, inputs, invert=False, **kwargs):
|
||||
if invert and not self._inverse_table:
|
||||
# If the user wants to perform an inverse lookup, we need to build an
|
||||
# inverse lookup table and initialize it to have the inverse of the
|
||||
# forward table's vocabulary.
|
||||
self._inverse_table = lookup_ops.MutableHashTable(
|
||||
key_dtype=self._output_dtype,
|
||||
value_dtype=self.dtype,
|
||||
default_value="",
|
||||
name=(self._name + "_inverse_index_table"))
|
||||
|
||||
tracked_inverse_table = self._add_trackable(
|
||||
self._inverse_table, trainable=False)
|
||||
# This is a workaround for summary() on this layer. Because the table is
|
||||
# not mutable during training, the effective number of parameters (and so
|
||||
# the weight shape) is 0; we add this as an attr so that the parameter
|
||||
# counting code in the Model object doesn't throw an attribute error.
|
||||
tracked_inverse_table.shape = tensor_shape.TensorShape((0,))
|
||||
|
||||
# This is a workaround for saving not working yet for MutableHashTables.
|
||||
# By replacing the existing function call by an explicit failure, we
|
||||
# can provide a more user-friendly error message.
|
||||
def fail(_):
|
||||
raise NotImplementedError(
|
||||
"Saving is not yet supported for IndexLookup layers.")
|
||||
|
||||
self._inverse_table._list_extra_dependencies_for_serialization = fail # pylint: disable=protected-access
|
||||
self._initialize_inverse_table()
|
||||
|
||||
return super(IndexLookup, self).__call__(inputs, invert=invert, **kwargs)
|
||||
|
||||
def call(self, inputs, invert=False):
|
||||
table = self._inverse_table if invert else self._table
|
||||
# The table lookup ops don't natively support ragged tensors, so if we have
|
||||
# a RT we need to use map_flat_values to look up every element.
|
||||
if ragged_tensor.is_ragged(inputs):
|
||||
indexed_data = ragged_functional_ops.map_flat_values(
|
||||
self._table.lookup, inputs)
|
||||
indexed_data = ragged_functional_ops.map_flat_values(table.lookup, inputs)
|
||||
else:
|
||||
indexed_data = self._table.lookup(inputs)
|
||||
indexed_data = table.lookup(inputs)
|
||||
|
||||
return indexed_data
|
||||
|
||||
|
@ -193,6 +193,54 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
self.assertAllClose(expected_output, output_dataset)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class InverseLookupOutputTest(keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
||||
def test_inverse_output(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", "michigan"]])
|
||||
expected_ints = [[2, 3, 4, 5], [5, 4, 2, 1]]
|
||||
# Note that the token 'michigan' has been replaced by ''. This is because
|
||||
# 'michigan' is OOV for this layer.
|
||||
expected_strings = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", ""]])
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(max_tokens=None)
|
||||
layer.set_vocabulary(vocab_data)
|
||||
int_data = layer(input_data)
|
||||
string_data = layer(int_data, invert=True)
|
||||
model = keras.Model(inputs=input_data, outputs=[int_data, string_data])
|
||||
int_outputs, string_outputs = model.predict(input_array)
|
||||
self.assertAllEqual(expected_ints, int_outputs)
|
||||
self.assertAllEqual(expected_strings, string_outputs)
|
||||
|
||||
def test_inverse_output_serialization(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", "michigan"]])
|
||||
expected_ints = [[2, 3, 4, 5], [5, 4, 2, 1]]
|
||||
# Note that the token 'michigan' has been replaced by ''. This is because
|
||||
# 'michigan' is OOV for this layer.
|
||||
expected_strings = np.array([["earth", "wind", "and", "fire"],
|
||||
["fire", "and", "earth", ""]])
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(max_tokens=None)
|
||||
layer.set_vocabulary(vocab_data)
|
||||
int_data = layer(input_data)
|
||||
string_data = layer(int_data, invert=True)
|
||||
model = keras.Model(inputs=input_data, outputs=[int_data, string_data])
|
||||
|
||||
with CustomObjectScope({"IndexLookup": get_layer_class()}):
|
||||
new_model = keras.Model.from_config(model.get_config())
|
||||
new_model.set_weights(model.get_weights())
|
||||
int_outputs, string_outputs = new_model.predict(input_array)
|
||||
self.assertAllEqual(expected_ints, int_outputs)
|
||||
self.assertAllEqual(expected_strings, string_outputs)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
|
||||
class IndexLookupSaveableTest(keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
@ -71,9 +71,18 @@ class IndexLookup(index_lookup.IndexLookup,
|
||||
def _clear_table(self):
|
||||
keys, _ = self._table.export()
|
||||
K.get_session().run(self._table.remove(keys))
|
||||
if self._inverse_table:
|
||||
keys, _ = self._inverse_table.export()
|
||||
K.get_session().run(self._inverse_table.remove(keys))
|
||||
|
||||
def _insert_table_data(self, keys, values):
|
||||
K.get_session().run(self._table.insert(keys, values))
|
||||
if self._inverse_table:
|
||||
K.get_session().run(self._inverse_table.insert(values, keys))
|
||||
|
||||
def _initialize_inverse_table(self):
|
||||
keys, values = self._table.export()
|
||||
K.get_session().run(self._inverse_table.insert(values, keys))
|
||||
|
||||
def _to_numpy(self, data):
|
||||
"""Converts preprocessed inputs into numpy arrays."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user