Add an 'invert' arg to lookup layers.

PiperOrigin-RevId: 312329926
Change-Id: If00e4f169412d7b8e5ebc2b74dae65ade4b0fd0a
This commit is contained in:
A. Unique TensorFlower 2020-05-19 12:28:42 -07:00 committed by TensorFlower Gardener
parent 34a68f2752
commit c12107003b
7 changed files with 324 additions and 39 deletions

View File

@ -75,6 +75,8 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
only used when performing an inverse lookup.
vocabulary: An optional list of vocabulary terms. If the list contains the
same token multiple times, an error will be thrown.
invert: If true, this layer will map indices to vocabulary items instead
of mapping vocabulary items to indices.
"""
# TODO(momernick): Add an examples section to the docstring.
@ -84,17 +86,22 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
mask_token,
oov_token,
vocabulary=None,
invert=False,
**kwargs):
# If max_tokens is set, the value must be greater than 1 - otherwise we
# are creating a 0-element vocab, which doesn't make sense.
if max_tokens is not None and max_tokens <= 1:
raise ValueError("If set, max_tokens must be greater than 1.")
raise ValueError("If set, `max_tokens` must be greater than 1.")
if num_oov_indices < 0:
raise ValueError("num_oov_indices must be greater than 0. You passed %s" %
num_oov_indices)
raise ValueError("`num_oov_indices` must be greater than 0. You passed "
"%s" % num_oov_indices)
if invert and num_oov_indices != 1:
raise ValueError("`num_oov_tokens` must be 1 when `invert` is True.")
self.invert = invert
self.max_tokens = max_tokens
self.num_oov_indices = num_oov_indices
self.oov_token = oov_token
@ -117,10 +124,19 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
self._output_dtype = dtypes.int64
if invert:
key_dtype = self._output_dtype
value_dtype = self.dtype
oov_value = self.oov_token
else:
key_dtype = self.dtype
value_dtype = self._output_dtype
oov_value = self._oov_value
self._table = lookup_ops.MutableHashTable(
key_dtype=self.dtype,
value_dtype=self._output_dtype,
default_value=self._oov_value,
key_dtype=key_dtype,
value_dtype=value_dtype,
default_value=oov_value,
name=(self._name + "_index_table"))
tracked_table = self._add_trackable(self._table, trainable=False)
# This is a workaround for summary() on this layer. Because the table is
@ -149,7 +165,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
def compute_output_signature(self, input_spec):
  """Builds the `TensorSpec` describing this layer's output.

  Arguments:
    input_spec: A spec object exposing `shape.as_list()`; only its shape is
      read here.

  Returns:
    A `tensor_spec.TensorSpec` whose shape comes from `compute_output_shape`
    and whose dtype is `self.dtype` when `self.invert` is set, otherwise
    `self._output_dtype`.
  """
  output_shape = self.compute_output_shape(input_spec.shape.as_list())
  # The dtype depends on the lookup direction, so it is computed once here
  # instead of first storing a hard-coded int64 that is never read.
  output_dtype = self.dtype if self.invert else self._output_dtype
  return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
def adapt(self, data, reset_state=True):
@ -176,13 +192,18 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
keys, values = self._table_handler.data()
# This is required because the MutableHashTable doesn't preserve insertion
# order, but we rely on the order of the array to assign indices.
return [x for _, x in sorted(zip(values, keys))]
if self.invert:
# If we are inverting, the vocabulary is in the values instead of keys.
return [x for _, x in sorted(zip(keys, values))]
else:
return [x for _, x in sorted(zip(values, keys))]
def vocab_size(self):
  """Returns the vocabulary size reported by this layer's table handler."""
  handler = self._table_handler
  return handler.vocab_size()
def get_config(self):
config = {
"invert": self.invert,
"max_tokens": self.max_tokens,
"num_oov_indices": self.num_oov_indices,
"oov_token": self.oov_token,
@ -198,33 +219,15 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
# abstraction for ease of saving!) we return 0.
return 0
def set_vocabulary(self, vocab):
"""Sets vocabulary (and optionally document frequency) data for this layer.
This method sets the vocabulary for this layer directly, instead of
analyzing a dataset through 'adapt'. It should be used whenever the vocab
information is already known. If vocabulary data is already present in the
layer, this method will either replace it
Arguments:
vocab: An array of string tokens.
Raises:
ValueError: If there are too many inputs, the inputs do not match, or
input data is missing.
"""
def _set_forward_vocabulary(self, vocab):
"""Sets vocabulary data for this layer when inverse is False."""
table_utils.validate_vocabulary_is_unique(vocab)
should_have_mask = self.mask_token is not None
if should_have_mask:
has_mask = vocab[0] == self.mask_token
oov_start = 1
else:
has_mask = False
oov_start = 0
has_mask = vocab[0] == self.mask_token
oov_start = 1 if should_have_mask else 0
should_have_oov = self.num_oov_indices > 0
should_have_oov = (self.num_oov_indices > 0) and not self.invert
if should_have_oov:
oov_end = oov_start + self.num_oov_indices
expected_oov = [self.oov_token] * self.num_oov_indices
@ -293,6 +296,65 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
special_token_values = np.arange(num_special_tokens, dtype=np.int64)
self._table_handler.insert(special_tokens, special_token_values)
def _set_inverse_vocabulary(self, vocab):
  """Sets vocabulary data for this layer when `invert` is True.

  The table is keyed by integer index and stores the vocabulary terms as
  values. If the layer has a mask token and `vocab` does not already start
  with it, the mask token is inserted at index 0 and the passed terms are
  shifted up by one index.

  Arguments:
    vocab: A list or array of unique vocabulary terms.

  Raises:
    ValueError: If the mask token appears in `vocab` somewhere other than
      index 0, or if the resulting vocabulary is larger than `max_tokens`.
      The uniqueness validation may also reject the vocabulary.
  """
  table_utils.validate_vocabulary_is_unique(vocab)

  should_have_mask = self.mask_token is not None
  # Check the length first so an empty vocabulary does not raise IndexError.
  has_mask = len(vocab) > 0 and vocab[0] == self.mask_token
  insert_mask = should_have_mask and not has_mask

  # The mask token is the only special token stored in the inverse table.
  num_special_tokens = 1 if should_have_mask else 0
  tokens = vocab if insert_mask else vocab[num_special_tokens:]
  if should_have_mask and self.mask_token in tokens:
    # `vocab` may be a numpy array, which has no `.index()` method.
    mask_index = list(tokens).index(self.mask_token)
    raise ValueError("Reserved mask token %s was found in the passed "
                     "vocabulary at index %s. Please either remove the "
                     "reserved token from the vocabulary or change the "
                     "mask token for this layer." %
                     (self.mask_token, mask_index))

  total_vocab_size = len(vocab) + (num_special_tokens if insert_mask else 0)
  if self.max_tokens is not None and total_vocab_size > self.max_tokens:
    raise ValueError(
        "Attempted to set a vocabulary larger than the maximum vocab size. "
        "Passed vocab size is %s, max vocab size is %s." %
        (total_vocab_size, self.max_tokens))

  # When the mask token is inserted it takes index 0, so the passed terms
  # start right after it.
  start_index = num_special_tokens if insert_mask else 0
  keys = np.arange(start_index, len(vocab) + start_index, dtype=np.int64)

  self._table_handler.clear()
  if len(vocab) > 0:
    self._table_handler.insert(keys, vocab)

  if insert_mask:
    mask_keys = np.arange(num_special_tokens, dtype=np.int64)
    self._table_handler.insert(mask_keys, [self.mask_token])
def set_vocabulary(self, vocab):
  """Sets vocabulary data for this layer, honoring the `invert` setting.

  This method sets the vocabulary for this layer directly, instead of
  analyzing a dataset through 'adapt'. It should be used whenever the vocab
  information is already known. The work is delegated to the inverse setter
  when `self.invert` is true and to the forward setter otherwise.

  Arguments:
    vocab: An array of vocabulary terms.

  Raises:
    ValueError: If the direction-specific setter rejects the vocabulary (for
      example, when it is larger than `max_tokens`).
  """
  setter = (self._set_inverse_vocabulary
            if self.invert else self._set_forward_vocabulary)
  setter(vocab)
def _set_state_variables(self, updates):
if not self.built:
raise RuntimeError("_set_state_variables() must be called after build().")

View File

@ -77,6 +77,30 @@ def _get_end_to_end_test_cases():
"input_dtype":
dtypes.string
},
{
"testcase_name":
"test_inverse_strings_soft_vocab_cap",
# Create an array where 'earth' is the most frequent term, followed by
# 'wind', then 'and', then 'fire'. This ensures that the vocab
# accumulator is sorting by frequency.
"vocab_data":
np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
"input_data": np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
"kwargs": {
"max_tokens": None,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"dtype": dtypes.string,
"invert": True
},
"expected_output":
np.array([[b"earth"], [b"wind"], [b"and"], [b"fire"], [b"fire"],
[b"and"], [b"earth"], [b"[OOV]"]]),
"input_dtype":
dtypes.int64
},
{
"testcase_name":
"test_ints_soft_vocab_cap",
@ -125,7 +149,11 @@ class IndexLookupLayerTest(keras_parameterized.TestCase,
use_dataset, expected_output,
input_dtype):
cls = get_layer_class()
expected_output_dtype = dtypes.int64
if "invert" in kwargs and kwargs["invert"]:
expected_output_dtype = kwargs["dtype"]
else:
expected_output_dtype = dtypes.int64
input_shape = input_data.shape
if use_dataset:
@ -156,7 +184,10 @@ class IndexLookupLayerTest(keras_parameterized.TestCase,
expected_output_dtype=expected_output_dtype,
validate_training=False,
adapt_data=vocab_data)
self.assertAllClose(expected_output, output_data)
if "invert" in kwargs and kwargs["invert"]:
self.assertAllEqual(expected_output, output_data)
else:
self.assertAllClose(expected_output, output_data)
@keras_parameterized.run_all_keras_modes
@ -748,6 +779,118 @@ class IndexLookupVocabularyTest(keras_parameterized.TestCase,
layer.set_vocabulary(vocab_data)
# Tests for the lookup layer constructed with `invert=True`: explicit
# vocabularies, round-tripping through set_vocabulary()/get_vocabulary(), and
# rejection of invalid vocabularies.
@keras_parameterized.run_all_keras_modes
class IndexLookupInverseVocabularyTest(
keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
# Inverse lookup through a Keras model with a vocabulary passed to the
# constructor. Per the fixtures below, indices 2..5 are expected to map to
# "earth".."fire" and index 1 to "[OOV]".
# NOTE(review): presumably index 0 is taken by the mask token "" - confirm.
def test_int_output_explicit_vocab(self):
vocab_data = ["[OOV]", "earth", "wind", "and", "fire"]
input_array = np.array([[2, 3, 4, 5], [5, 4, 2, 1]])
expected_output = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "[OOV]"]])
input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
layer = get_layer_class()(
vocabulary=vocab_data,
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
dtype=dtypes.string,
invert=True)
# Despite its name, `int_data` is the output of an inverted layer created
# with dtype=dtypes.string.
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
# A 5-term string vocabulary that already starts with the mask token must be
# accepted with max_tokens=5 and be returned unchanged by get_vocabulary().
def test_vocab_with_max_cap(self):
vocab_data = ["", "[OOV]", "wind", "and", "fire"]
layer = get_layer_class()(
max_tokens=5,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
dtype=dtypes.string,
invert=True)
layer.set_vocabulary(vocab_data)
returned_vocab = layer.get_vocabulary()
self.assertAllEqual(vocab_data, returned_vocab)
# Same round-trip check as above, with an int64 vocabulary whose mask token
# is 0 and whose OOV token is -1.
def test_int_vocab_with_max_cap(self):
vocab_data = [0, -1, 42, 1276, 1138]
layer = get_layer_class()(
max_tokens=5,
num_oov_indices=1,
mask_token=0,
oov_token=-1,
dtype=dtypes.int64,
invert=True)
layer.set_vocabulary(vocab_data)
returned_vocab = layer.get_vocabulary()
self.assertAllEqual(vocab_data, returned_vocab)
# A duplicated term passed to the constructor must raise a ValueError that
# mentions the repeated term.
def test_non_unique_vocab_fails(self):
vocab_data = ["earth", "wind", "and", "fire", "fire"]
with self.assertRaisesRegex(ValueError, ".*repeated term.*fire.*"):
_ = get_layer_class()(
vocabulary=vocab_data,
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
dtype=dtypes.string,
invert=True)
# A duplicated term passed to set_vocabulary() must raise the same error.
def test_vocab_with_repeated_element_fails(self):
vocab_data = ["earth", "earth", "wind", "and", "fire"]
layer = get_layer_class()(
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
dtype=dtypes.string,
invert=True)
with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"):
layer.set_vocabulary(vocab_data)
# The mask token appearing at a position other than index 0 must be rejected
# with the "Reserved mask" error.
def test_vocab_with_reserved_mask_element_fails(self):
vocab_data = ["earth", "mask_token", "wind", "and", "fire"]
layer = get_layer_class()(
max_tokens=None,
num_oov_indices=1,
mask_token="mask_token",
oov_token="[OOV]",
dtype=dtypes.string,
invert=True)
with self.assertRaisesRegex(ValueError, ".*Reserved mask.*"):
layer.set_vocabulary(vocab_data)
# Integer counterpart of test_non_unique_vocab_fails.
def test_non_unique_int_vocab_fails(self):
vocab_data = [12, 13, 14, 15, 15]
with self.assertRaisesRegex(ValueError, ".*repeated term.*15.*"):
_ = get_layer_class()(
vocabulary=vocab_data,
max_tokens=None,
num_oov_indices=1,
mask_token=0,
oov_token=-1,
dtype=dtypes.int64,
invert=True)
# Integer counterpart of test_vocab_with_repeated_element_fails.
def test_int_vocab_with_repeated_element_fails(self):
vocab_data = [11, 11, 34, 23, 124]
layer = get_layer_class()(
max_tokens=None,
num_oov_indices=1,
mask_token=0,
oov_token=-1,
dtype=dtypes.int64,
invert=True)
with self.assertRaisesRegex(ValueError, ".*repeated term.*11.*"):
layer.set_vocabulary(vocab_data)
@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
class IndexLookupSaveableTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):

View File

@ -57,6 +57,8 @@ class IntegerLookup(index_lookup.IndexLookup):
a vocabulary to load into this layer. The file should contain one value
per line. If the list or file contains the same token multiple times, an
error will be thrown.
invert: If true, this layer will map indices to vocabulary items instead
of mapping vocabulary items to indices.
"""
def __init__(self,
@ -65,6 +67,7 @@ class IntegerLookup(index_lookup.IndexLookup):
mask_value=0,
oov_value=-1,
vocabulary=None,
invert=False,
**kwargs):
allowed_dtypes = [dtypes.int64]
@ -95,6 +98,7 @@ class IntegerLookup(index_lookup.IndexLookup):
mask_token=mask_value,
oov_token=oov_value,
vocabulary=vocabulary,
invert=invert,
**kwargs)
def get_config(self):

View File

@ -347,6 +347,36 @@ class IntegerLookupOutputTest(keras_parameterized.TestCase,
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_inverse_output(self):
  """Checks that an inverted layer maps indices back to vocabulary values."""
  vocabulary = [0, -1, 42, 1138, 725, 1729]
  indices = np.array([[2, 3, 4, 5], [5, 4, 2, 1]])
  expected_values = np.array([[42, 1138, 725, 1729], [1729, 725, 42, -1]])

  index_input = keras.Input(shape=(None,), dtype=dtypes.int64)
  inverse_layer = get_layer_class()(invert=True)
  inverse_layer.set_vocabulary(vocabulary)
  model = keras.Model(inputs=index_input, outputs=inverse_layer(index_input))
  self.assertAllEqual(expected_values, model.predict(indices))
def test_forward_backward_output(self):
  """Checks a forward layer chained into an inverted layer.

  In-vocabulary values are expected to come back unchanged, while the
  out-of-vocabulary input 203 is expected to come back as -1.
  """
  vocab_data = [42, 1138, 725, 1729]
  input_array = np.array([[42, 1138, 725, 1729], [1729, 725, 42, 203]])
  expected_output = np.array([[42, 1138, 725, 1729], [1729, 725, 42, -1]])

  input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
  layer = get_layer_class()()
  layer.set_vocabulary(vocab_data)
  # The inverse layer is built once, from the forward layer's full vocabulary;
  # a throwaway default instance is not created first.
  inverse_layer = get_layer_class()(
      vocabulary=layer.get_vocabulary(), invert=True)
  int_data = layer(input_data)
  inverse_data = inverse_layer(int_data)
  model = keras.Model(inputs=input_data, outputs=inverse_data)
  output_dataset = model.predict(input_array)
  self.assertAllEqual(expected_output, output_dataset)
@keras_parameterized.run_all_keras_modes
class IntegerLookupVocabularyTest(

View File

@ -58,6 +58,8 @@ class StringLookup(index_lookup.IndexLookup):
one token per line. If the list or file contains the same token multiple
times, an error will be thrown.
encoding: The Python string encoding to use. Defaults to `'utf-8'`.
invert: If true, this layer will map indices to vocabulary items instead
of mapping vocabulary items to indices.
"""
def __init__(self,
@ -67,6 +69,7 @@ class StringLookup(index_lookup.IndexLookup):
oov_token="[OOV]",
vocabulary=None,
encoding="utf-8",
invert=False,
**kwargs):
allowed_dtypes = [dtypes.string]
@ -89,6 +92,7 @@ class StringLookup(index_lookup.IndexLookup):
mask_token=mask_token,
oov_token=oov_token,
vocabulary=vocabulary,
invert=invert,
**kwargs)
def get_config(self):

View File

@ -187,6 +187,36 @@ class StringLookupVocabularyTest(keras_parameterized.TestCase,
with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"):
_ = get_layer_class()(vocabulary=vocab_path)
def test_inverse_layer(self):
  """Checks that an inverted layer maps indices back to string terms."""
  vocabulary = ["earth", "wind", "and", "fire"]
  indices = np.array([[1, 2, 3, 4], [4, 3, 1, 0]])
  expected_terms = np.array([["earth", "wind", "and", "fire"],
                             ["fire", "and", "earth", ""]])

  index_input = keras.Input(shape=(None,), dtype=dtypes.int64)
  inverse_layer = get_layer_class()(vocabulary=vocabulary, invert=True)
  model = keras.Model(inputs=index_input, outputs=inverse_layer(index_input))
  self.assertAllEqual(expected_terms, model.predict(indices))
def test_forward_backward_layer(self):
  """Checks a forward layer chained into an inverted layer.

  In-vocabulary terms are expected to come back unchanged, while the
  out-of-vocabulary term "michigan" is expected to come back as "[OOV]".
  """
  vocabulary = ["earth", "wind", "and", "fire"]
  terms = np.array([["earth", "wind", "and", "fire"],
                    ["fire", "and", "earth", "michigan"]])
  expected_terms = np.array([["earth", "wind", "and", "fire"],
                             ["fire", "and", "earth", "[OOV]"]])

  term_input = keras.Input(shape=(None,), dtype=dtypes.string)
  forward_layer = get_layer_class()(vocabulary=vocabulary)
  backward_layer = get_layer_class()(
      vocabulary=forward_layer.get_vocabulary(), invert=True)
  round_trip = backward_layer(forward_layer(term_input))
  model = keras.Model(inputs=term_input, outputs=round_trip)
  self.assertAllEqual(expected_terms, model.predict(terms))
@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
class StringLookupSaveableTest(keras_parameterized.TestCase,

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
@ -44,6 +45,14 @@ from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
def string_test(actual, expected):
  """Asserts exact equality via `np.testing.assert_array_equal`."""
  check_exact = np.testing.assert_array_equal
  check_exact(actual, expected)
def numeric_test(actual, expected):
  """Asserts `actual` is close to `expected` (rtol=1e-3, atol=1e-6)."""
  tolerances = {"rtol": 1e-3, "atol": 1e-6}
  np.testing.assert_allclose(actual, expected, **tolerances)
def get_test_data(train_samples,
test_samples,
input_shape,
@ -132,6 +141,11 @@ def layer_test(layer_cls,
if expected_output_dtype is None:
expected_output_dtype = input_dtype
if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
assert_equal = string_test
else:
assert_equal = numeric_test
# instantiation
kwargs = kwargs or {}
layer = layer_cls(**kwargs)
@ -199,8 +213,7 @@ def layer_test(layer_cls,
(layer_cls.__name__, x, actual_output.dtype,
computed_output_signature.dtype, kwargs))
if expected_output is not None:
np.testing.assert_allclose(actual_output, expected_output,
rtol=1e-3, atol=1e-6)
assert_equal(actual_output, expected_output)
# test serialization, weight setting at model level
model_config = model.get_config()
@ -209,7 +222,7 @@ def layer_test(layer_cls,
weights = model.get_weights()
recovered_model.set_weights(weights)
output = recovered_model.predict(input_data)
np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)
assert_equal(output, actual_output)
# test training mode (e.g. useful for dropout tests)
# Rebuild the model to avoid the graph being reused between predict() and
@ -254,8 +267,7 @@ def layer_test(layer_cls,
computed_output_shape,
kwargs))
if expected_output is not None:
np.testing.assert_allclose(actual_output, expected_output,
rtol=1e-3, atol=1e-6)
assert_equal(actual_output, expected_output)
# test serialization, weight setting at model level
model_config = model.get_config()
@ -264,7 +276,7 @@ def layer_test(layer_cls,
weights = model.get_weights()
recovered_model.set_weights(weights)
output = recovered_model.predict(input_data)
np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)
assert_equal(output, actual_output)
# for further checks in the caller function
return actual_output