Allow index_lookup to perform inverse lookups.

PiperOrigin-RevId: 293471194
Change-Id: Idd2d08e8663fb8ae69df3eb6923e1e8fa876d154
This commit is contained in:
A. Unique TensorFlower 2020-02-05 15:52:23 -08:00 committed by TensorFlower Gardener
parent a3594d8126
commit 0f27efc0f8
3 changed files with 120 additions and 16 deletions

View File

@ -73,6 +73,12 @@ class IndexLookup(CombinerPreprocessingLayer):
mask_zero: If True, input values of 0 (for integers) and `""` (for strings)
will be treated as masked values and assigned an output value of 0. If
this option is set, `reserve_zero` must also be set. Defaults to False.
Call arguments:
inputs: The data to look up. Can be a tf.Tensor or RaggedTensor.
invert: Controls the lookup direction. If False, the layer will map strings
to integers; if true, the layer will map integers to strings. Defaults
to False.
"""
# TODO(momernick): Add an examples section to the docstring.
@ -141,12 +147,17 @@ class IndexLookup(CombinerPreprocessingLayer):
self._output_dtype = dtypes.int32
else:
self._output_dtype = dtypes.int64
self._table = lookup_ops.MutableHashTable(
key_dtype=self.dtype,
value_dtype=self._output_dtype,
default_value=self._oov_value,
name=(self._name + "_index_table"))
tracked_table = self._add_trackable(self._table, trainable=False)
# This is a workaround for summary() on this layer. Because the table is
# not mutable during training, the effective number of parameters (and so
# the weight shape) is 0; we add this as an attr so that the parameter
# counting code in the Model object doesn't throw an attribute error.
tracked_table.shape = tensor_shape.TensorShape((0,))
# This is a workaround for saving not working yet for MutableHashTables.
# By replacing the existing function call by an explicit failure, we
@ -154,15 +165,8 @@ class IndexLookup(CombinerPreprocessingLayer):
def fail(_):
raise NotImplementedError(
"Saving is not yet supported for IndexLookup layers.")
self._table._list_extra_dependencies_for_serialization = fail # pylint: disable=protected-access
tracked_table = self._add_trackable(self._table, trainable=False)
# This is a workaround for summary() on this layer. Because the table is
# not mutable during training, the effective number of parameters (and so
# the weight shape) is 0; we add this as an attr so that the parameter
# counting code in the Model object doesn't throw an attribute error.
tracked_table.shape = tensor_shape.TensorShape((0,))
self._inverse_table = None
def _get_table_data(self):
keys, values = self._table.export()
@ -174,6 +178,9 @@ class IndexLookup(CombinerPreprocessingLayer):
def _clear_table(self):
keys, _ = self._table.export()
self._table.remove(keys)
if self._inverse_table:
keys, _ = self._inverse_table.export()
self._inverse_table.remove(keys)
def _insert_table_data(self, keys, values):
if len(values) != len(keys):
@ -181,6 +188,12 @@ class IndexLookup(CombinerPreprocessingLayer):
"Keys had size %s, values had size %s." %
(len(keys), len(values)))
self._table.insert(keys, values)
if self._inverse_table:
self._inverse_table.insert(values, keys)
def _initialize_inverse_table(self):
keys, values = self._table.export()
self._inverse_table.insert(values, keys)
def _to_numpy(self, preprocessed_data):
"""Converts preprocessed inputs into numpy arrays."""
@ -207,9 +220,12 @@ class IndexLookup(CombinerPreprocessingLayer):
def compute_output_shape(self, input_shape):
return input_shape
def compute_output_signature(self, input_spec):
def compute_output_signature(self, input_spec, invert=False):
output_shape = self.compute_output_shape(input_spec.shape.as_list())
output_dtype = dtypes.int64
if invert:
output_dtype = dtypes.string
else:
output_dtype = dtypes.int64
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
def adapt(self, data, reset_state=True):
@ -243,7 +259,7 @@ class IndexLookup(CombinerPreprocessingLayer):
"max_tokens": self.max_tokens,
"num_oov_tokens": self.num_oov_tokens,
"reserve_zero": self.reserve_zero,
"mask_zero": self.mask_zero
"mask_zero": self.mask_zero,
}
base_config = super(IndexLookup, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@ -301,14 +317,45 @@ class IndexLookup(CombinerPreprocessingLayer):
raise RuntimeError("_set_state_variables() must be called after build().")
self.set_vocabulary(updates[_VOCAB_NAME])
def call(self, inputs):
def __call__(self, inputs, invert=False, **kwargs):
if invert and not self._inverse_table:
# If the user wants to perform an inverse lookup, we need to build an
# inverse lookup table and initialize it to have the inverse of the
# forward table's vocabulary.
self._inverse_table = lookup_ops.MutableHashTable(
key_dtype=self._output_dtype,
value_dtype=self.dtype,
default_value="",
name=(self._name + "_inverse_index_table"))
tracked_inverse_table = self._add_trackable(
self._inverse_table, trainable=False)
# This is a workaround for summary() on this layer. Because the table is
# not mutable during training, the effective number of parameters (and so
# the weight shape) is 0; we add this as an attr so that the parameter
# counting code in the Model object doesn't throw an attribute error.
tracked_inverse_table.shape = tensor_shape.TensorShape((0,))
# This is a workaround for saving not working yet for MutableHashTables.
# By replacing the existing function call by an explicit failure, we
# can provide a more user-friendly error message.
def fail(_):
raise NotImplementedError(
"Saving is not yet supported for IndexLookup layers.")
self._inverse_table._list_extra_dependencies_for_serialization = fail # pylint: disable=protected-access
self._initialize_inverse_table()
return super(IndexLookup, self).__call__(inputs, invert=invert, **kwargs)
def call(self, inputs, invert=False):
table = self._inverse_table if invert else self._table
# The table lookup ops don't natively support ragged tensors, so if we have
# a RT we need to use map_flat_values to look up every element.
if ragged_tensor.is_ragged(inputs):
indexed_data = ragged_functional_ops.map_flat_values(
self._table.lookup, inputs)
indexed_data = ragged_functional_ops.map_flat_values(table.lookup, inputs)
else:
indexed_data = self._table.lookup(inputs)
indexed_data = table.lookup(inputs)
return indexed_data

View File

@ -193,6 +193,54 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
self.assertAllClose(expected_output, output_dataset)
@keras_parameterized.run_all_keras_modes
class InverseLookupOutputTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
def test_inverse_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
expected_ints = [[2, 3, 4, 5], [5, 4, 2, 1]]
# Note that the token 'michigan' has been replaced by ''. This is because
# 'michigan' is OOV for this layer.
expected_strings = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", ""]])
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
layer = get_layer_class()(max_tokens=None)
layer.set_vocabulary(vocab_data)
int_data = layer(input_data)
string_data = layer(int_data, invert=True)
model = keras.Model(inputs=input_data, outputs=[int_data, string_data])
int_outputs, string_outputs = model.predict(input_array)
self.assertAllEqual(expected_ints, int_outputs)
self.assertAllEqual(expected_strings, string_outputs)
def test_inverse_output_serialization(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
expected_ints = [[2, 3, 4, 5], [5, 4, 2, 1]]
# Note that the token 'michigan' has been replaced by ''. This is because
# 'michigan' is OOV for this layer.
expected_strings = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", ""]])
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
layer = get_layer_class()(max_tokens=None)
layer.set_vocabulary(vocab_data)
int_data = layer(input_data)
string_data = layer(int_data, invert=True)
model = keras.Model(inputs=input_data, outputs=[int_data, string_data])
with CustomObjectScope({"IndexLookup": get_layer_class()}):
new_model = keras.Model.from_config(model.get_config())
new_model.set_weights(model.get_weights())
int_outputs, string_outputs = new_model.predict(input_array)
self.assertAllEqual(expected_ints, int_outputs)
self.assertAllEqual(expected_strings, string_outputs)
@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
class IndexLookupSaveableTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):

View File

@ -71,9 +71,18 @@ class IndexLookup(index_lookup.IndexLookup,
def _clear_table(self):
keys, _ = self._table.export()
K.get_session().run(self._table.remove(keys))
if self._inverse_table:
keys, _ = self._inverse_table.export()
K.get_session().run(self._inverse_table.remove(keys))
def _insert_table_data(self, keys, values):
K.get_session().run(self._table.insert(keys, values))
if self._inverse_table:
K.get_session().run(self._inverse_table.insert(values, keys))
def _initialize_inverse_table(self):
keys, values = self._table.export()
K.get_session().run(self._inverse_table.insert(values, keys))
def _to_numpy(self, data):
"""Converts preprocessed inputs into numpy arrays."""