Add an 'invert' arg to lookup layers.
PiperOrigin-RevId: 312329926
Change-Id: If00e4f169412d7b8e5ebc2b74dae65ade4b0fd0a
commit c12107003b
parent 34a68f2752
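At a glance, the new flag reverses the direction of a lookup layer: a forward layer maps tokens to indices, an inverted layer maps indices back to tokens. A minimal round-trip sketch, mirroring the test_forward_backward_layer test below and assuming the era's experimental import path (the import is an assumption, not part of this diff):

    import numpy as np
    from tensorflow.keras.layers.experimental.preprocessing import StringLookup

    forward = StringLookup(vocabulary=["earth", "wind", "and", "fire"])
    # Build the inverse from the forward layer's full vocabulary so the
    # special tokens (mask, OOV) occupy the same indices in both directions.
    backward = StringLookup(vocabulary=forward.get_vocabulary(), invert=True)

    ids = forward(np.array([["fire", "and", "earth"]]))
    tokens = backward(ids)  # recovers [[b"fire", b"and", b"earth"]]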
@@ -75,6 +75,8 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
       only used when performing an inverse lookup.
     vocabulary: An optional list of vocabulary terms. If the list contains the
       same token multiple times, an error will be thrown.
+    invert: If true, this layer will map indices to vocabulary items instead
+      of mapping vocabulary items to indices.
   """
   # TODO(momernick): Add an examples section to the docstring.

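Concretely, the two directions described by the new docstring lines differ only in which side of the vocabulary mapping acts as the key. A plain-Python sketch of the contract (the vocab list here is hypothetical, no TF required):

    vocab = ["[OOV]", "earth", "wind", "and", "fire"]

    # forward (default): vocabulary item -> index
    forward = {token: index for index, token in enumerate(vocab)}
    # invert=True: index -> vocabulary item
    inverted = {index: token for index, token in enumerate(vocab)}

    assert forward["wind"] == 2
    assert inverted[2] == "wind"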
@@ -84,17 +86,22 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
                mask_token,
                oov_token,
                vocabulary=None,
+               invert=False,
                **kwargs):

     # If max_tokens is set, the value must be greater than 1 - otherwise we
     # are creating a 0-element vocab, which doesn't make sense.
     if max_tokens is not None and max_tokens <= 1:
-      raise ValueError("If set, max_tokens must be greater than 1.")
+      raise ValueError("If set, `max_tokens` must be greater than 1.")

     if num_oov_indices < 0:
-      raise ValueError("num_oov_indices must be greater than 0. You passed %s" %
-                       num_oov_indices)
+      raise ValueError("`num_oov_indices` must be greater than 0. You passed "
+                       "%s" % num_oov_indices)

+    if invert and num_oov_indices != 1:
+      raise ValueError("`num_oov_tokens` must be 1 when `invert` is True.")
+
+    self.invert = invert
     self.max_tokens = max_tokens
     self.num_oov_indices = num_oov_indices
     self.oov_token = oov_token
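The new validation means an inverted layer supports exactly one OOV slot, since every unknown index can only decode to the single oov_token (note the message names `num_oov_tokens` while the argument is `num_oov_indices`). A sketch of the expected failure mode, assuming IndexLookup has been imported from the internal index_lookup module:

    try:
      IndexLookup(max_tokens=None, num_oov_indices=2, mask_token="",
                  oov_token="[OOV]", invert=True)
    except ValueError as e:
      print(e)  # `num_oov_tokens` must be 1 when `invert` is True.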
@@ -117,10 +124,19 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):

     self._output_dtype = dtypes.int64

+    if invert:
+      key_dtype = self._output_dtype
+      value_dtype = self.dtype
+      oov_value = self.oov_token
+    else:
+      key_dtype = self.dtype
+      value_dtype = self._output_dtype
+      oov_value = self._oov_value
+
     self._table = lookup_ops.MutableHashTable(
-        key_dtype=self.dtype,
-        value_dtype=self._output_dtype,
-        default_value=self._oov_value,
+        key_dtype=key_dtype,
+        value_dtype=value_dtype,
+        default_value=oov_value,
         name=(self._name + "_index_table"))
     tracked_table = self._add_trackable(self._table, trainable=False)
     # This is a workaround for summary() on this layer. Because the table is
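The same MutableHashTable backs both modes; only the key/value roles swap. Forward, the keys are tokens (self.dtype) and the values are int64 indices; inverted, the keys are int64 indices, the values are tokens, and the default returned on a miss becomes the OOV token itself rather than the OOV index. A dict-based sketch of the same idea (illustrative values only):

    # forward table: token -> index, miss -> OOV index
    forward_table = {"earth": 2, "wind": 3}
    print(forward_table.get("mars", 1))     # 1 (the OOV index)

    # inverted table: index -> token, miss -> OOV token
    inverted_table = {2: "earth", 3: "wind"}
    print(inverted_table.get(99, "[OOV]"))  # "[OOV]"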
@@ -149,7 +165,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):

   def compute_output_signature(self, input_spec):
     output_shape = self.compute_output_shape(input_spec.shape.as_list())
-    output_dtype = dtypes.int64
+    output_dtype = self.dtype if self.invert else self._output_dtype
     return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)

   def adapt(self, data, reset_state=True):
@@ -176,13 +192,18 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     keys, values = self._table_handler.data()
     # This is required because the MutableHashTable doesn't preserve insertion
     # order, but we rely on the order of the array to assign indices.
-    return [x for _, x in sorted(zip(values, keys))]
+    if self.invert:
+      # If we are inverting, the vocabulary is in the values instead of keys.
+      return [x for _, x in sorted(zip(keys, values))]
+    else:
+      return [x for _, x in sorted(zip(values, keys))]

   def vocab_size(self):
     return self._table_handler.vocab_size()

   def get_config(self):
     config = {
+        "invert": self.invert,
         "max_tokens": self.max_tokens,
         "num_oov_indices": self.num_oov_indices,
         "oov_token": self.oov_token,
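get_vocabulary has to undo the hash table's arbitrary iteration order. Forward tables store token -> index, so sorting pairs by the index (the value) recovers vocabulary order; inverted tables store index -> token, so the sort key is the key. A sketch with plain sequences:

    # forward: keys are tokens, values are indices
    keys, values = ["wind", "earth"], [3, 2]
    print([tok for _, tok in sorted(zip(values, keys))])  # ['earth', 'wind']

    # inverted: keys are indices, values are tokens
    keys, values = [3, 2], ["wind", "earth"]
    print([tok for _, tok in sorted(zip(keys, values))])  # ['earth', 'wind']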
@@ -198,33 +219,15 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     # abstraction for ease of saving!) we return 0.
     return 0

-  def set_vocabulary(self, vocab):
-    """Sets vocabulary (and optionally document frequency) data for this layer.
-
-    This method sets the vocabulary for this layer directly, instead of
-    analyzing a dataset through 'adapt'. It should be used whenever the vocab
-    information is already known. If vocabulary data is already present in the
-    layer, this method will either replace it
-
-    Arguments:
-      vocab: An array of string tokens.
-
-    Raises:
-      ValueError: If there are too many inputs, the inputs do not match, or
-        input data is missing.
-    """
-
+  def _set_forward_vocabulary(self, vocab):
+    """Sets vocabulary data for this layer when inverse is False."""
     table_utils.validate_vocabulary_is_unique(vocab)

     should_have_mask = self.mask_token is not None
-    if should_have_mask:
-      has_mask = vocab[0] == self.mask_token
-      oov_start = 1
-    else:
-      has_mask = False
-      oov_start = 0
+    has_mask = vocab[0] == self.mask_token
+    oov_start = 1 if should_have_mask else 0

-    should_have_oov = self.num_oov_indices > 0
+    should_have_oov = (self.num_oov_indices > 0) and not self.invert
     if should_have_oov:
       oov_end = oov_start + self.num_oov_indices
       expected_oov = [self.oov_token] * self.num_oov_indices
@@ -293,6 +296,65 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
       special_token_values = np.arange(num_special_tokens, dtype=np.int64)
       self._table_handler.insert(special_tokens, special_token_values)

+  def _set_inverse_vocabulary(self, vocab):
+    """Sets vocabulary data for this layer when inverse is True."""
+    table_utils.validate_vocabulary_is_unique(vocab)
+
+    should_have_mask = self.mask_token is not None
+    has_mask = vocab[0] == self.mask_token
+
+    insert_special_tokens = should_have_mask and not has_mask
+    special_tokens = [] if self.mask_token is None else [self.mask_token]
+
+    num_special_tokens = len(special_tokens)
+    tokens = vocab if insert_special_tokens else vocab[num_special_tokens:]
+    if self.mask_token in tokens:
+      raise ValueError("Reserved mask token %s was found in the passed "
+                       "vocabulary at index %s. Please either remove the "
+                       "reserved token from the vocabulary or change the "
+                       "mask token for this layer." %
+                       (self.mask_token, tokens.index(self.mask_token)))
+
+    if insert_special_tokens:
+      total_vocab_size = len(vocab) + num_special_tokens
+    else:
+      total_vocab_size = len(vocab)
+    if self.max_tokens is not None and total_vocab_size > self.max_tokens:
+      raise ValueError(
+          "Attempted to set a vocabulary larger than the maximum vocab size. "
+          "Passed vocab size is %s, max vocab size is %s." %
+          (total_vocab_size, self.max_tokens))
+
+    start_index = num_special_tokens if insert_special_tokens else 0
+    values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64)
+
+    self._table_handler.clear()
+    self._table_handler.insert(values, vocab)
+
+    if insert_special_tokens and num_special_tokens > 0:
+      special_token_values = np.arange(num_special_tokens, dtype=np.int64)
+      self._table_handler.insert(special_token_values, special_tokens)
+
+  def set_vocabulary(self, vocab):
+    """Sets vocabulary data for this layer.
+
+    This method sets the vocabulary for this layer directly, instead of
+    analyzing a dataset through 'adapt'. It should be used whenever the vocab
+    information is already known. If vocabulary data is already present in the
+    layer, this method will replace it.
+
+    Arguments:
+      vocab: An array of string tokens.
+
+    Raises:
+      ValueError: If there are too many inputs, the inputs do not match, or
+        input data is missing.
+    """
+    if self.invert:
+      self._set_inverse_vocabulary(vocab)
+    else:
+      self._set_forward_vocabulary(vocab)
+
   def _set_state_variables(self, updates):
     if not self.built:
       raise RuntimeError("_set_state_variables() must be called after build().")
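set_vocabulary now dispatches on self.invert. On the inverse path, indices become the table keys, and a user-supplied vocabulary that lacks an explicit mask entry is shifted by one slot so the mask token can occupy index 0. A sketch of the resulting mapping (hypothetical values, mirroring the logic above):

    vocab = ["earth", "wind", "and", "fire"]  # no leading mask token
    mask_token = ""

    # insert_special_tokens is True, so user tokens start at index 1
    table = {0: mask_token}
    table.update({i + 1: tok for i, tok in enumerate(vocab)})
    # {0: '', 1: 'earth', 2: 'wind', 3: 'and', 4: 'fire'}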
@@ -77,6 +77,30 @@ def _get_end_to_end_test_cases():
           "input_dtype":
               dtypes.string
       },
+      {
+          "testcase_name":
+              "test_inverse_strings_soft_vocab_cap",
+          # Create an array where 'earth' is the most frequent term, followed by
+          # 'wind', then 'and', then 'fire'. This ensures that the vocab
+          # accumulator is sorting by frequency.
+          "vocab_data":
+              np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                        ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+          "input_data": np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
+          "kwargs": {
+              "max_tokens": None,
+              "num_oov_indices": 1,
+              "mask_token": "",
+              "oov_token": "[OOV]",
+              "dtype": dtypes.string,
+              "invert": True
+          },
+          "expected_output":
+              np.array([[b"earth"], [b"wind"], [b"and"], [b"fire"], [b"fire"],
+                        [b"and"], [b"earth"], [b"[OOV]"]]),
+          "input_dtype":
+              dtypes.int64
+      },
       {
           "testcase_name":
              "test_ints_soft_vocab_cap",
@@ -125,7 +149,11 @@ class IndexLookupLayerTest(keras_parameterized.TestCase,
                                        use_dataset, expected_output,
                                        input_dtype):
     cls = get_layer_class()
-    expected_output_dtype = dtypes.int64
+    if "invert" in kwargs and kwargs["invert"]:
+      expected_output_dtype = kwargs["dtype"]
+    else:
+      expected_output_dtype = dtypes.int64
+
     input_shape = input_data.shape

     if use_dataset:
@@ -156,7 +184,10 @@ class IndexLookupLayerTest(keras_parameterized.TestCase,
         expected_output_dtype=expected_output_dtype,
         validate_training=False,
         adapt_data=vocab_data)
-    self.assertAllClose(expected_output, output_data)
+    if "invert" in kwargs and kwargs["invert"]:
+      self.assertAllEqual(expected_output, output_data)
+    else:
+      self.assertAllClose(expected_output, output_data)


 @keras_parameterized.run_all_keras_modes
@@ -748,6 +779,118 @@ class IndexLookupVocabularyTest(keras_parameterized.TestCase,
     layer.set_vocabulary(vocab_data)


+@keras_parameterized.run_all_keras_modes
+class IndexLookupInverseVocabularyTest(
+    keras_parameterized.TestCase,
+    preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_int_output_explicit_vocab(self):
+    vocab_data = ["[OOV]", "earth", "wind", "and", "fire"]
+    input_array = np.array([[2, 3, 4, 5], [5, 4, 2, 1]])
+    expected_output = np.array([["earth", "wind", "and", "fire"],
+                                ["fire", "and", "earth", "[OOV]"]])
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
+    layer = get_layer_class()(
+        vocabulary=vocab_data,
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token="",
+        oov_token="[OOV]",
+        dtype=dtypes.string,
+        invert=True)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+  def test_vocab_with_max_cap(self):
+    vocab_data = ["", "[OOV]", "wind", "and", "fire"]
+    layer = get_layer_class()(
+        max_tokens=5,
+        num_oov_indices=1,
+        mask_token="",
+        oov_token="[OOV]",
+        dtype=dtypes.string,
+        invert=True)
+    layer.set_vocabulary(vocab_data)
+    returned_vocab = layer.get_vocabulary()
+    self.assertAllEqual(vocab_data, returned_vocab)
+
+  def test_int_vocab_with_max_cap(self):
+    vocab_data = [0, -1, 42, 1276, 1138]
+    layer = get_layer_class()(
+        max_tokens=5,
+        num_oov_indices=1,
+        mask_token=0,
+        oov_token=-1,
+        dtype=dtypes.int64,
+        invert=True)
+    layer.set_vocabulary(vocab_data)
+    returned_vocab = layer.get_vocabulary()
+    self.assertAllEqual(vocab_data, returned_vocab)
+
+  def test_non_unique_vocab_fails(self):
+    vocab_data = ["earth", "wind", "and", "fire", "fire"]
+    with self.assertRaisesRegex(ValueError, ".*repeated term.*fire.*"):
+      _ = get_layer_class()(
+          vocabulary=vocab_data,
+          max_tokens=None,
+          num_oov_indices=1,
+          mask_token="",
+          oov_token="[OOV]",
+          dtype=dtypes.string,
+          invert=True)
+
+  def test_vocab_with_repeated_element_fails(self):
+    vocab_data = ["earth", "earth", "wind", "and", "fire"]
+    layer = get_layer_class()(
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token="",
+        oov_token="[OOV]",
+        dtype=dtypes.string,
+        invert=True)
+    with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"):
+      layer.set_vocabulary(vocab_data)
+
+  def test_vocab_with_reserved_mask_element_fails(self):
+    vocab_data = ["earth", "mask_token", "wind", "and", "fire"]
+    layer = get_layer_class()(
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token="mask_token",
+        oov_token="[OOV]",
+        dtype=dtypes.string,
+        invert=True)
+    with self.assertRaisesRegex(ValueError, ".*Reserved mask.*"):
+      layer.set_vocabulary(vocab_data)
+
+  def test_non_unique_int_vocab_fails(self):
+    vocab_data = [12, 13, 14, 15, 15]
+    with self.assertRaisesRegex(ValueError, ".*repeated term.*15.*"):
+      _ = get_layer_class()(
+          vocabulary=vocab_data,
+          max_tokens=None,
+          num_oov_indices=1,
+          mask_token=0,
+          oov_token=-1,
+          dtype=dtypes.int64,
+          invert=True)
+
+  def test_int_vocab_with_repeated_element_fails(self):
+    vocab_data = [11, 11, 34, 23, 124]
+    layer = get_layer_class()(
+        max_tokens=None,
+        num_oov_indices=1,
+        mask_token=0,
+        oov_token=-1,
+        dtype=dtypes.int64,
+        invert=True)
+    with self.assertRaisesRegex(ValueError, ".*repeated term.*11.*"):
+      layer.set_vocabulary(vocab_data)
+
+
 @keras_parameterized.run_all_keras_modes(always_skip_eager=True)
 class IndexLookupSaveableTest(keras_parameterized.TestCase,
                               preprocessing_test_utils.PreprocessingLayerTest):
@@ -57,6 +57,8 @@ class IntegerLookup(index_lookup.IndexLookup):
       a vocabulary to load into this layer. The file should contain one value
       per line. If the list or file contains the same token multiple times, an
       error will be thrown.
+    invert: If true, this layer will map indices to vocabulary items instead
+      of mapping vocabulary items to indices.
   """

   def __init__(self,
@@ -65,6 +67,7 @@ class IntegerLookup(index_lookup.IndexLookup):
                mask_value=0,
                oov_value=-1,
                vocabulary=None,
+               invert=False,
                **kwargs):
     allowed_dtypes = [dtypes.int64]

@@ -95,6 +98,7 @@ class IntegerLookup(index_lookup.IndexLookup):
         mask_token=mask_value,
         oov_token=oov_value,
         vocabulary=vocabulary,
+        invert=invert,
         **kwargs)

   def get_config(self):
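IntegerLookup simply forwards the flag to the base class, so an inverted integer layer decodes indices back to integer vocabulary values. A sketch mirroring test_inverse_output below, assuming IntegerLookup is importable at this revision (it may only be reachable through the internal integer_lookup module rather than a public namespace):

    import numpy as np

    # Vocabulary laid out as get_vocabulary() would return it:
    # index 0 is the mask value, index 1 the OOV value.
    layer = IntegerLookup(invert=True)
    layer.set_vocabulary([0, -1, 42, 1138, 725, 1729])
    print(layer(np.array([[2, 3, 1]])))  # -> [[42, 1138, -1]]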
@@ -347,6 +347,36 @@ class IntegerLookupOutputTest(keras_parameterized.TestCase,
     output_dataset = model.predict(input_array)
     self.assertAllEqual(expected_output, output_dataset)

+  def test_inverse_output(self):
+    vocab_data = [0, -1, 42, 1138, 725, 1729]
+    input_array = np.array([[2, 3, 4, 5], [5, 4, 2, 1]])
+    expected_output = np.array([[42, 1138, 725, 1729], [1729, 725, 42, -1]])
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
+    layer = get_layer_class()(invert=True)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+  def test_forward_backward_output(self):
+    vocab_data = [42, 1138, 725, 1729]
+    input_array = np.array([[42, 1138, 725, 1729], [1729, 725, 42, 203]])
+    expected_output = np.array([[42, 1138, 725, 1729], [1729, 725, 42, -1]])
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
+    layer = get_layer_class()()
+    inverse_layer = get_layer_class()()
+    layer.set_vocabulary(vocab_data)
+    inverse_layer = get_layer_class()(
+        vocabulary=layer.get_vocabulary(), invert=True)
+    int_data = layer(input_data)
+    inverse_data = inverse_layer(int_data)
+    model = keras.Model(inputs=input_data, outputs=inverse_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+
 @keras_parameterized.run_all_keras_modes
 class IntegerLookupVocabularyTest(
@@ -58,6 +58,8 @@ class StringLookup(index_lookup.IndexLookup):
       one token per line. If the list or file contains the same token multiple
       times, an error will be thrown.
     encoding: The Python string encoding to use. Defaults to `'utf-8'`.
+    invert: If true, this layer will map indices to vocabulary items instead
+      of mapping vocabulary items to indices.
   """

   def __init__(self,
@@ -67,6 +69,7 @@ class StringLookup(index_lookup.IndexLookup):
                oov_token="[OOV]",
                vocabulary=None,
                encoding="utf-8",
+               invert=False,
                **kwargs):
     allowed_dtypes = [dtypes.string]

@@ -89,6 +92,7 @@ class StringLookup(index_lookup.IndexLookup):
         mask_token=mask_token,
         oov_token=oov_token,
         vocabulary=vocabulary,
+        invert=invert,
         **kwargs)

   def get_config(self):
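StringLookup forwards the flag the same way. A single-layer decode sketch, mirroring test_inverse_layer below; the import path is an assumption for this era of TF:

    import numpy as np
    from tensorflow.keras.layers.experimental.preprocessing import StringLookup

    # A vocabulary without a mask entry gets shifted by one, so index 0
    # decodes to the mask token ("") and user tokens start at index 1.
    layer = StringLookup(vocabulary=["earth", "wind", "and", "fire"],
                         invert=True)
    print(layer(np.array([[1, 2, 0]])))  # -> [[b'earth', b'wind', b'']]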
@@ -187,6 +187,36 @@ class StringLookupVocabularyTest(keras_parameterized.TestCase,
     with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"):
       _ = get_layer_class()(vocabulary=vocab_path)

+  def test_inverse_layer(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([[1, 2, 3, 4], [4, 3, 1, 0]])
+    expected_output = np.array([["earth", "wind", "and", "fire"],
+                                ["fire", "and", "earth", ""]])
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.int64)
+    layer = get_layer_class()(vocabulary=vocab_data, invert=True)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+  def test_forward_backward_layer(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = np.array([["earth", "wind", "and", "fire"],
+                                ["fire", "and", "earth", "[OOV]"]])
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(vocabulary=vocab_data)
+    invert_layer = get_layer_class()(
+        vocabulary=layer.get_vocabulary(), invert=True)
+    int_data = layer(input_data)
+    out_data = invert_layer(int_data)
+    model = keras.Model(inputs=input_data, outputs=out_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+
 @keras_parameterized.run_all_keras_modes(always_skip_eager=True)
 class StringLookupSaveableTest(keras_parameterized.TestCase,
@@ -25,6 +25,7 @@ import numpy as np

 from tensorflow.python import tf2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util

@@ -44,6 +45,14 @@ from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect


+def string_test(actual, expected):
+  np.testing.assert_array_equal(actual, expected)
+
+
+def numeric_test(actual, expected):
+  np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6)
+
+
 def get_test_data(train_samples,
                   test_samples,
                   input_shape,

@@ -132,6 +141,11 @@ def layer_test(layer_cls,
   if expected_output_dtype is None:
     expected_output_dtype = input_dtype

+  if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
+    assert_equal = string_test
+  else:
+    assert_equal = numeric_test
+
   # instantiation
   kwargs = kwargs or {}
   layer = layer_cls(**kwargs)

@@ -199,8 +213,7 @@ def layer_test(layer_cls,
                        (layer_cls.__name__, x, actual_output.dtype,
                         computed_output_signature.dtype, kwargs))
   if expected_output is not None:
-    np.testing.assert_allclose(actual_output, expected_output,
-                               rtol=1e-3, atol=1e-6)
+    assert_equal(actual_output, expected_output)

   # test serialization, weight setting at model level
   model_config = model.get_config()

@@ -209,7 +222,7 @@ def layer_test(layer_cls,
   weights = model.get_weights()
   recovered_model.set_weights(weights)
   output = recovered_model.predict(input_data)
-  np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)
+  assert_equal(output, actual_output)

   # test training mode (e.g. useful for dropout tests)
   # Rebuild the model to avoid the graph being reused between predict() and

@@ -254,8 +267,7 @@ def layer_test(layer_cls,
                      computed_output_shape,
                      kwargs))
   if expected_output is not None:
-    np.testing.assert_allclose(actual_output, expected_output,
-                               rtol=1e-3, atol=1e-6)
+    assert_equal(actual_output, expected_output)

   # test serialization, weight setting at model level
   model_config = model.get_config()

@@ -264,7 +276,7 @@ def layer_test(layer_cls,
   weights = model.get_weights()
   recovered_model.set_weights(weights)
   output = recovered_model.predict(input_data)
-  np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)
+  assert_equal(output, actual_output)

   # for further checks in the caller function
   return actual_output
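The harness change exists because np.testing.assert_allclose only makes sense for numeric arrays; string outputs, like an inverted lookup's, need exact comparison. A minimal standalone sketch of the same dispatch (assert_outputs_match is a hypothetical equivalent of the string_test/numeric_test pair above):

    import numpy as np

    def assert_outputs_match(actual, expected):
      # Strings get exact equality; numbers get a tolerance-based check.
      if np.asarray(actual).dtype.kind in ("U", "S", "O"):
        np.testing.assert_array_equal(actual, expected)
      else:
        np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6)

    assert_outputs_match(np.array([b"earth"]), np.array([b"earth"]))
    assert_outputs_match(np.array([1.0001]), np.array([1.0]))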