Move tf.keras.experimental.SequenceFeatures to keras package.

PiperOrigin-RevId: 309421612
Change-Id: I55e17386071dad91cce2e2500f30fc9e3c3cf657
This commit is contained in:
Scott Zhu 2020-05-01 09:41:40 -07:00 committed by TensorFlower Gardener
parent a952fa1b1c
commit c53eca0d6a
17 changed files with 1203 additions and 838 deletions

View File

@ -13,6 +13,7 @@ py_library(
":feature_column",
":feature_column_v2",
"//tensorflow/python:util",
"//tensorflow/python/keras/feature_column",
],
)

View File

@ -27,4 +27,5 @@ from tensorflow.python.feature_column.feature_column import *
from tensorflow.python.feature_column.feature_column_v2 import *
from tensorflow.python.feature_column.sequence_feature_column import *
from tensorflow.python.feature_column.serialization import *
from tensorflow.python.keras.feature_column.sequence_feature_column import *
# pylint: enable=unused-import,line-too-long

View File

@ -316,14 +316,6 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase):
self.assertIsInstance(revived, fc.DenseFeatures)
self.assertNotIsInstance(revived, dense_features_v2.DenseFeatures)
# Round-trips a SequenceFeatures layer built from one sequence numeric column
# through keras.layers.serialize / keras.layers.deserialize and checks that the
# revived object is still an fc.SequenceFeatures instance.
def test_serialization_sequence_features(self):
rating = fc.sequence_numeric_column('rating')
sequence_feature = fc.SequenceFeatures([rating])
config = keras.layers.serialize(sequence_feature)
revived = keras.layers.deserialize(config)
self.assertIsInstance(revived, fc.SequenceFeatures)
# This test is an example for a regression on categorical inputs, i.e.,
# the output is 0.4, 0.6, 0.9 when input is 'alpha', 'beta', 'gamma'
# separately.

View File

@ -30,156 +30,14 @@ from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.layers import serialization as layer_serialization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@keras_export('keras.experimental.SequenceFeatures')
class SequenceFeatures(fc._BaseFeaturesLayer):
  """A layer for sequence input.

  All `feature_columns` must be sequence dense columns with the same
  `sequence_length`. The output of this layer can be fed into sequence
  networks, such as RNN.

  The output of this layer is a 3D `Tensor` of shape `[batch_size, T, D]`.
  `T` is the maximum sequence length for this batch, which could differ from
  batch to batch.

  If multiple `feature_columns` are given with `Di` `num_elements` each, their
  outputs are concatenated. So, the final `Tensor` has shape
  `[batch_size, T, D0 + D1 + ... + Dn]`.

  Example:

  ```python
  # Behavior of some cells or feature columns may depend on whether we are in
  # training or inference mode, e.g. applying dropout.
  training = True
  rating = sequence_numeric_column('rating')
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [rating, watches_embedding]

  sequence_input_layer = SequenceFeatures(columns)
  features = tf.io.parse_example(...,
                                 features=make_parse_example_spec(columns))
  sequence_input, sequence_length = sequence_input_layer(
      features, training=training)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size, training=training)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, training=training)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```
  """

  def __init__(
      self,
      feature_columns,
      trainable=True,
      name=None,
      **kwargs):
    """Constructs a SequenceFeatures layer.

    Args:
      feature_columns: An iterable of dense sequence columns. Valid columns are
        - `embedding_column` that wraps a `sequence_categorical_column_with_*`
        - `sequence_numeric_column`.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the SequenceFeatures.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: If any of the `feature_columns` is not a
        `SequenceDenseColumn`.
    """
    super(SequenceFeatures, self).__init__(
        feature_columns=feature_columns,
        trainable=trainable,
        name=name,
        expected_column_type=fc.SequenceDenseColumn,
        **kwargs)

  @property
  def _is_feature_layer(self):
    """Marks this layer as a feature layer; always `True`."""
    return True

  def _target_shape(self, input_shape, total_elements):
    """Returns the 3D shape `(input_shape[0], input_shape[1], total_elements)`."""
    return (input_shape[0], input_shape[1], total_elements)

  def call(self, features, training=None):
    """Returns sequence input corresponding to the `feature_columns`.

    Args:
      features: A dict mapping keys to tensors.
      training: Python boolean or None, indicating whether the layer is being
        run in training mode. This argument is passed to the call method of any
        `FeatureColumn` that takes a `training` argument. For example, if a
        `FeatureColumn` performed dropout, the column could expose a `training`
        argument to control whether the dropout should be applied. If `None`,
        defaults to `tf.keras.backend.learning_phase()`.

    Returns:
      An `(input_layer, sequence_length)` tuple where:
      - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
          `T` is the maximum sequence length for this batch, which could differ
          from batch to batch. `D` is the sum of `num_elements` for all
          `feature_columns`.
      - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
          length for each example.

    Raises:
      ValueError: If features are not a dictionary.
    """
    if not isinstance(features, dict):
      # Format the offending value into the message itself. Passing it as a
      # second positional argument makes str(error) render as a 2-tuple
      # instead of a readable sentence.
      raise ValueError(
          'We expected a dictionary here. Instead we got: {}'.format(features))
    if training is None:
      training = backend.learning_phase()
    transformation_cache = fc.FeatureTransformationCache(features)
    output_tensors = []
    sequence_lengths = []

    for column in self._feature_columns:
      with ops.name_scope(column.name):
        try:
          dense_tensor, sequence_length = column.get_sequence_dense_tensor(
              transformation_cache, self._state_manager, training=training)
        except TypeError:
          # Fallback for columns whose get_sequence_dense_tensor does not
          # accept a `training` keyword.
          # NOTE(review): this also retries on a TypeError raised inside a
          # column that does accept `training` -- confirm that is acceptable.
          dense_tensor, sequence_length = column.get_sequence_dense_tensor(
              transformation_cache, self._state_manager)
        # Flattens the final dimension to produce a 3D Tensor.
        output_tensors.append(self._process_dense_tensor(column, dense_tensor))
        sequence_lengths.append(sequence_length)

    # Check and process sequence lengths.
    fc._verify_static_batch_size_equality(sequence_lengths,
                                          self._feature_columns)
    sequence_length = _assert_all_equal_and_return(sequence_lengths)

    return self._verify_and_concat_tensors(output_tensors), sequence_length
# Register the SequenceFeatures class under the name 'SequenceFeatures' with
# both the feature-column v1 and v2 injection hooks of layer_serialization.
# NOTE(review): presumably this is what lets keras.layers.deserialize resolve
# the class by name -- confirm against layer_serialization.
layer_serialization.inject_feature_column_v1_objects(
'SequenceFeatures', SequenceFeatures)
layer_serialization.inject_feature_column_v2_objects(
'SequenceFeatures', SequenceFeatures)
def concatenate_context_input(context_input, sequence_input):
"""Replicates `context_input` across all timesteps of `sequence_input`.

View File

@ -24,130 +24,13 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.feature_column import dense_features
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat
# Integration tests that push sequence feature columns through
# SequenceFeatures / DenseFeatures and, in one case, on into an RNN layer.
class SequenceFeatureColumnIntegrationTest(test.TestCase):
# Builds a SequenceExample with two context features ('int_ctx', 'float_ctx')
# and two five-step feature lists ('int_list', 'str_list'); step `val` holds
# `val` copies of its value.
def _make_sequence_example(self):
example = example_pb2.SequenceExample()
example.context.feature['int_ctx'].int64_list.value.extend([5])
example.context.feature['float_ctx'].float_list.value.extend([123.6])
for val in range(0, 10, 2):
feat = feature_pb2.Feature()
feat.int64_list.value.extend([val] * val)
example.feature_lists.feature_list['int_list'].feature.extend([feat])
for val in range(1, 11, 2):
feat = feature_pb2.Feature()
feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val)
example.feature_lists.feature_list['str_list'].feature.extend([feat])
return example
# Returns (context columns, sequence columns) used by the parsing test below.
# NOTE(review): the hash-bucket column is keyed 'bytes_list' while
# _make_sequence_example above writes 'str_list' -- confirm which key is
# intended.
def _build_feature_columns(self):
col = fc.categorical_column_with_identity('int_ctx', num_buckets=100)
ctx_cols = [
fc.embedding_column(col, dimension=10),
fc.numeric_column('float_ctx')
]
identity_col = sfc.sequence_categorical_column_with_identity(
'int_list', num_buckets=10)
bucket_col = sfc.sequence_categorical_column_with_hash_bucket(
'bytes_list', hash_bucket_size=100)
seq_cols = [
fc.embedding_column(identity_col, dimension=10),
fc.embedding_column(bucket_col, dimension=20)
]
return ctx_cols, seq_cols
# Parses 100 serialized copies of one SequenceExample in batches of 20, feeds
# one batch through SequenceFeatures + DenseFeatures, concatenates context
# onto the sequence input and checks the RNN output shape.
def test_sequence_example_into_input_layer(self):
# NOTE(review): _make_sequence_example is called as a bare name, not via
# self; this resolves only if a module-level function of that name exists
# outside this view -- confirm.
examples = [_make_sequence_example().SerializeToString()] * 100
ctx_cols, seq_cols = self._build_feature_columns()
def _parse_example(example):
ctx, seq = parsing_ops.parse_single_sequence_example(
example,
context_features=fc.make_parse_example_spec_v2(ctx_cols),
sequence_features=fc.make_parse_example_spec_v2(seq_cols))
ctx.update(seq)
return ctx
ds = dataset_ops.Dataset.from_tensor_slices(examples)
ds = ds.map(_parse_example)
ds = ds.batch(20)
# Test on a single batch
features = dataset_ops.make_one_shot_iterator(ds).get_next()
# Tile the context features across the sequence features
sequence_input_layer = sfc.SequenceFeatures(seq_cols)
seq_layer, _ = sequence_input_layer(features)
input_layer = dense_features.DenseFeatures(ctx_cols)
ctx_layer = input_layer(features)
input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)
rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
output = rnn_layer(input_layer)
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
features_r = sess.run(features)
self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6])
output_r = sess.run(output)
self.assertAllEqual(output_r.shape, [20, 10])
# Shares one embedding (Ones initializer, dimension 4, 'sum' combiner) between
# a sequence and a non-sequence categorical column and checks both outputs.
@test_util.run_deprecated_v1
def test_shared_sequence_non_sequence_into_input_layer(self):
non_seq = fc.categorical_column_with_identity('non_seq',
num_buckets=10)
seq = sfc.sequence_categorical_column_with_identity('seq',
num_buckets=10)
shared_non_seq, shared_seq = fc.shared_embedding_columns_v2(
[non_seq, seq],
dimension=4,
combiner='sum',
initializer=init_ops_v2.Ones(),
shared_embedding_collection_name='shared')
seq = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1], [1, 0]],
values=[0, 1, 2],
dense_shape=[2, 2])
non_seq = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1], [1, 0]],
values=[0, 1, 2],
dense_shape=[2, 2])
features = {'seq': seq, 'non_seq': non_seq}
# Run the shared embedding through the sequence layer and the dense layer.
seq_input, seq_length = sfc.SequenceFeatures([shared_seq])(features)
non_seq_input = dense_features.DenseFeatures([shared_non_seq])(features)
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output_seq, output_seq_length, output_non_seq = sess.run(
[seq_input, seq_length, non_seq_input])
self.assertAllEqual(output_seq, [[[1, 1, 1, 1], [1, 1, 1, 1]],
[[1, 1, 1, 1], [0, 0, 0, 0]]])
self.assertAllEqual(output_seq_length, [2, 1])
self.assertAllEqual(output_non_seq, [[2, 2, 2, 2], [1, 1, 1, 1]])
class SequenceExampleParsingTest(test.TestCase):
def test_seq_ex_in_sequence_categorical_column_with_identity(self):

View File

@ -29,7 +29,6 @@ from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.feature_column import serialization
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
@ -49,538 +48,6 @@ def _initialized_session(config=None):
return sess
class SequenceFeaturesTest(test.TestCase, parameterized.TestCase):
# Feeds two sequence embedding columns ('aaa' with 2-d rows, 'bbb' with 3-d
# rows) through SequenceFeatures for 2D and 3D sparse inputs and checks the
# concatenated output and the sequence lengths.
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args_a': {
# example 0, ids [2]
# example 1, ids [0, 1]
'indices': ((0, 0), (1, 0), (1, 1)),
'values': (2, 0, 1),
'dense_shape': (2, 2)},
'sparse_input_args_b': {
# example 0, ids [1]
# example 1, ids [2, 0]
'indices': ((0, 0), (1, 0), (1, 1)),
'values': (1, 2, 0),
'dense_shape': (2, 2)},
'expected_input_layer': [
# example 0, ids_a [2], ids_b [1]
[[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
# example 1, ids_a [0, 1], ids_b [2, 0]
[[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],],
'expected_sequence_length': [1, 2]},
{'testcase_name': '3D',
'sparse_input_args_a': {
# feature 0, ids [[2], [0, 1]]
# feature 1, ids [[0, 0], [1]]
'indices': (
(0, 0, 0), (0, 1, 0), (0, 1, 1),
(1, 0, 0), (1, 0, 1), (1, 1, 0)),
'values': (2, 0, 1, 0, 0, 1),
'dense_shape': (2, 2, 2)},
'sparse_input_args_b': {
# feature 0, ids [[1, 1], [1]]
# feature 1, ids [[2], [0]]
'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
'values': (1, 1, 1, 2, 0),
'dense_shape': (2, 2, 2)},
'expected_input_layer': [
# feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
[[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]],
# feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -]
[[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]],
'expected_sequence_length': [2, 2]},
)
@test_util.run_in_graph_and_eager_modes
def test_embedding_column(
self, sparse_input_args_a, sparse_input_args_b, expected_input_layer,
expected_sequence_length):
sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)
vocabulary_size = 3
embedding_dimension_a = 2
embedding_values_a = (
(1., 2.), # id 0
(3., 4.), # id 1
(5., 6.) # id 2
)
embedding_dimension_b = 3
embedding_values_b = (
(11., 12., 13.), # id 0
(14., 15., 16.), # id 1
(17., 18., 19.) # id 2
)
# Returns an initializer that asserts the requested shape/dtype and yields the
# fixed embedding table, so expected outputs can be computed by hand.
def _get_initializer(embedding_dimension, embedding_values):
def _initializer(shape, dtype, partition_info=None):
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
self.assertEqual(dtypes.float32, dtype)
self.assertIsNone(partition_info)
return embedding_values
return _initializer
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column_a = fc.embedding_column(
categorical_column_a,
dimension=embedding_dimension_a,
initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
embedding_column_b = fc.embedding_column(
categorical_column_b,
dimension=embedding_dimension_b,
initializer=_get_initializer(embedding_dimension_b, embedding_values_b))
# Test that columns are reordered alphabetically.
sequence_input_layer = sfc.SequenceFeatures(
[embedding_column_b, embedding_column_a])
input_layer, sequence_length = sequence_input_layer({
'aaa': sparse_input_a, 'bbb': sparse_input_b,})
self.evaluate(variables_lib.global_variables_initializer())
weights = sequence_input_layer.weights
# The name check below is order-insensitive (assertCountEqual), but the two
# value checks after it rely on weights[0] being 'aaa' and weights[1] 'bbb'.
self.assertCountEqual(
('sequence_features/aaa_embedding/embedding_weights:0',
'sequence_features/bbb_embedding/embedding_weights:0'),
tuple([v.name for v in weights]))
self.assertAllEqual(embedding_values_a, self.evaluate(weights[0]))
self.assertAllEqual(embedding_values_b, self.evaluate(weights[1]))
self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
self.assertAllEqual(
expected_sequence_length, self.evaluate(sequence_length))
@test_util.run_in_graph_and_eager_modes
def test_embedding_column_with_non_sequence_categorical(self):
"""Tests that error is raised for non-sequence embedding column."""
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 1),
dense_shape=(2, 2))
# Deliberately uses the plain fc.categorical_column_with_identity rather than
# the sfc.sequence_* variant; that is what the expected error complains about.
categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column_a = fc.embedding_column(
categorical_column_a, dimension=2)
sequence_input_layer = sfc.SequenceFeatures([embedding_column_a])
with self.assertRaisesRegexp(
ValueError,
r'In embedding_column: aaa_embedding\. categorical_column must be of '
r'type SequenceCategoricalColumn to use SequenceFeatures\.'):
_, _ = sequence_input_layer({'aaa': sparse_input})
# Checks SequenceFeatures over two sequence columns that share one embedding
# table: a single global variable is created and both columns look up the
# same rows.
# The body builds its own graph via ops.Graph().as_default() and evaluates
# with an explicit session (.eval(session=...)) rather than self.evaluate.
@test_util.run_in_graph_and_eager_modes
def test_shared_embedding_column(self):
with ops.Graph().as_default():
vocabulary_size = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 1),
dense_shape=(2, 2))
sparse_input_b = sparse_tensor.SparseTensorValue(
# example 0, ids [1]
# example 1, ids [2, 0]
indices=((0, 0), (1, 0), (1, 1)),
values=(1, 2, 0),
dense_shape=(2, 2))
embedding_dimension = 2
embedding_values = (
(1., 2.), # id 0
(3., 4.), # id 1
(5., 6.) # id 2
)
# Returns an initializer that asserts the requested shape/dtype and yields the
# fixed embedding table.
def _get_initializer(embedding_dimension, embedding_values):
def _initializer(shape, dtype, partition_info=None):
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
self.assertEqual(dtypes.float32, dtype)
self.assertIsNone(partition_info)
return embedding_values
return _initializer
expected_input_layer = [
# example 0, ids_a [2], ids_b [1]
[[5., 6., 3., 4.], [0., 0., 0., 0.]],
# example 1, ids_a [0, 1], ids_b [2, 0]
[[1., 2., 5., 6.], [3., 4., 1., 2.]],
]
expected_sequence_length = [1, 2]
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
# Test that columns are reordered alphabetically.
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension,
initializer=_get_initializer(embedding_dimension, embedding_values))
sequence_input_layer = sfc.SequenceFeatures(shared_embedding_columns)
input_layer, sequence_length = sequence_input_layer({
'aaa': sparse_input_a, 'bbb': sparse_input_b})
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertCountEqual(
('aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
with _initialized_session() as sess:
self.assertAllEqual(embedding_values,
global_vars[0].eval(session=sess))
self.assertAllEqual(expected_input_layer,
input_layer.eval(session=sess))
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
@test_util.run_deprecated_v1
def test_shared_embedding_column_with_non_sequence_categorical(self):
"""Tests that error is raised for non-sequence shared embedding column."""
vocabulary_size = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 1),
dense_shape=(2, 2))
sparse_input_b = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 1),
dense_shape=(2, 2))
# Both columns are plain (non-sequence) categorical columns on purpose; the
# expected error message names the first one, 'aaa_shared_embedding'.
categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b], dimension=2)
sequence_input_layer = sfc.SequenceFeatures(shared_embedding_columns)
with self.assertRaisesRegexp(
ValueError,
r'In embedding_column: aaa_shared_embedding\. categorical_column must '
r'be of type SequenceCategoricalColumn to use SequenceFeatures\.'):
_, _ = sequence_input_layer({'aaa': sparse_input_a,
'bbb': sparse_input_b})
# Feeds two sequence indicator columns ('aaa' with 3 buckets, 'bbb' with 2)
# through SequenceFeatures for 2D and 3D sparse inputs and checks the
# concatenated 5-wide output and the sequence lengths.
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args_a': {
# example 0, ids [2]
# example 1, ids [0, 1]
'indices': ((0, 0), (1, 0), (1, 1)),
'values': (2, 0, 1),
'dense_shape': (2, 2)},
'sparse_input_args_b': {
# example 0, ids [1]
# example 1, ids [1, 0]
'indices': ((0, 0), (1, 0), (1, 1)),
'values': (1, 1, 0),
'dense_shape': (2, 2)},
'expected_input_layer': [
# example 0, ids_a [2], ids_b [1]
[[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]],
# example 1, ids_a [0, 1], ids_b [1, 0]
[[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
'expected_sequence_length': [1, 2]},
{'testcase_name': '3D',
'sparse_input_args_a': {
# feature 0, ids [[2], [0, 1]]
# feature 1, ids [[0, 0], [1]]
'indices': (
(0, 0, 0), (0, 1, 0), (0, 1, 1),
(1, 0, 0), (1, 0, 1), (1, 1, 0)),
'values': (2, 0, 1, 0, 0, 1),
'dense_shape': (2, 2, 2)},
'sparse_input_args_b': {
# feature 0, ids [[1, 1], [1]]
# feature 1, ids [[1], [0]]
'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
'values': (1, 1, 1, 1, 0),
'dense_shape': (2, 2, 2)},
'expected_input_layer': [
# feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
[[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]],
# feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -]
[[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
'expected_sequence_length': [2, 2]},
)
@test_util.run_in_graph_and_eager_modes
def test_indicator_column(
self, sparse_input_args_a, sparse_input_args_b, expected_input_layer,
expected_sequence_length):
sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)
vocabulary_size_a = 3
vocabulary_size_b = 2
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size_a)
indicator_column_a = fc.indicator_column(categorical_column_a)
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size_b)
indicator_column_b = fc.indicator_column(categorical_column_b)
# Test that columns are reordered alphabetically.
sequence_input_layer = sfc.SequenceFeatures(
[indicator_column_b, indicator_column_a])
input_layer, sequence_length = sequence_input_layer({
'aaa': sparse_input_a, 'bbb': sparse_input_b})
self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
self.assertAllEqual(
expected_sequence_length, self.evaluate(sequence_length))
@test_util.run_in_graph_and_eager_modes
def test_indicator_column_with_non_sequence_categorical(self):
"""Tests that error is raised for non-sequence categorical column."""
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 1),
dense_shape=(2, 2))
# Deliberately wraps a plain (non-sequence) categorical column so the layer
# call below raises the ValueError matched by the regex.
categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column_a = fc.indicator_column(categorical_column_a)
sequence_input_layer = sfc.SequenceFeatures([indicator_column_a])
with self.assertRaisesRegexp(
ValueError,
r'In indicator_column: aaa_indicator\. categorical_column must be of '
r'type SequenceCategoricalColumn to use SequenceFeatures\.'):
_, _ = sequence_input_layer({'aaa': sparse_input})
# Feeds a single default-shape sequence_numeric_column('aaa') through
# SequenceFeatures for a 2D and a 3D sparse input and checks the dense output
# and the sequence lengths.
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, values [0., 1]
# example 1, [10.]
'indices': ((0, 0), (0, 1), (1, 0)),
'values': (0., 1., 10.),
'dense_shape': (2, 2)},
'expected_input_layer': [
[[0.], [1.]],
[[10.], [0.]]],
'expected_sequence_length': [2, 1]},
{'testcase_name': '3D',
'sparse_input_args': {
# feature 0, ids [[20, 3], [5]]
# feature 1, ids [[3], [8]]
'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
'values': (20., 3., 5., 3., 8.),
'dense_shape': (2, 2, 2)},
'expected_input_layer': [
[[20.], [3.], [5.], [0.]],
[[3.], [0.], [8.], [0.]]],
'expected_sequence_length': [2, 2]},
)
@test_util.run_in_graph_and_eager_modes
def test_numeric_column(
self, sparse_input_args, expected_input_layer, expected_sequence_length):
sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args)
numeric_column = sfc.sequence_numeric_column('aaa')
sequence_input_layer = sfc.SequenceFeatures([numeric_column])
input_layer, sequence_length = sequence_input_layer({'aaa': sparse_input})
self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
self.assertAllEqual(
expected_sequence_length, self.evaluate(sequence_length))
# Both cases carry the same twelve values, once as a (2, 8) and once as a
# (2, 2, 4) sparse tensor, and expect the same flattened [2, 2, 4] output for
# a column declared with shape=(2, 2).
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, values [0., 1., 2., 3., 4., 5., 6., 7.]
# example 1, [10., 11., 12., 13.]
'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
(0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 8)},
'expected_input_layer': [
# The output of numeric_column._get_dense_tensor should be flattened.
[[0., 1., 2., 3.], [4., 5., 6., 7.]],
[[10., 11., 12., 13.], [0., 0., 0., 0.]]],
'expected_sequence_length': [2, 1]},
{'testcase_name': '3D',
'sparse_input_args': {
# example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
# example 1, [[10., 11., 12., 13.], []]
'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3),
(1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 2, 4)},
'expected_input_layer': [
# The output of numeric_column._get_dense_tensor should be flattened.
[[0., 1., 2., 3.], [4., 5., 6., 7.]],
[[10., 11., 12., 13.], [0., 0., 0., 0.]]],
'expected_sequence_length': [2, 1]},
)
@test_util.run_in_graph_and_eager_modes
def test_numeric_column_multi_dim(
self, sparse_input_args, expected_input_layer, expected_sequence_length):
"""Tests SequenceFeatures for multi-dimensional numeric_column."""
sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args)
numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
sequence_input_layer = sfc.SequenceFeatures([numeric_column])
input_layer, sequence_length = sequence_input_layer({'aaa': sparse_input})
self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
self.assertAllEqual(
expected_sequence_length, self.evaluate(sequence_length))
@test_util.run_in_graph_and_eager_modes
def test_sequence_length_not_equal(self):
"""Tests that an error is raised when sequence lengths are not equal."""
# Input a with sequence_length = [2, 1]
sparse_input_a = sparse_tensor.SparseTensorValue(
indices=((0, 0), (0, 1), (1, 0)),
values=(0., 1., 10.),
dense_shape=(2, 2))
# Input b with sequence_length = [1, 1]
sparse_input_b = sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0)),
values=(1., 10.),
dense_shape=(2, 2))
numeric_column_a = sfc.sequence_numeric_column('aaa')
numeric_column_b = sfc.sequence_numeric_column('bbb')
sequence_input_layer = sfc.SequenceFeatures(
[numeric_column_a, numeric_column_b])
# The mismatch is expected to surface as an InvalidArgumentError whose message
# matches 'Condition x == y did not hold'.
with self.assertRaisesRegexp(
errors.InvalidArgumentError, r'Condition x == y did not hold.*'):
_, sequence_length = sequence_input_layer({
'aaa': sparse_input_a,
'bbb': sparse_input_b
})
self.evaluate(sequence_length)
# Only the static shape reported by get_shape() is checked here; the output
# tensor is never evaluated.
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
# example 1, [[[10., 11.], [12., 13.]]]
'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
(0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 8)},
'expected_shape': [2, 2, 4]},
{'testcase_name': '3D',
'sparse_input_args': {
# example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
# example 1, [[10., 11., 12., 13.], []]
'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3),
(1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 2, 4)},
'expected_shape': [2, 2, 4]},
)
@test_util.run_in_graph_and_eager_modes
def test_static_shape_from_tensors_numeric(
self, sparse_input_args, expected_shape):
"""Tests that we return a known static shape when we have one."""
sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args)
numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
sequence_input_layer = sfc.SequenceFeatures([numeric_column])
input_layer, _ = sequence_input_layer({'aaa': sparse_input})
shape = input_layer.get_shape()
self.assertEqual(shape, expected_shape)
# Indicator-column counterpart of the numeric static-shape test: with
# num_buckets=3 the expected static shape is [4, 2, 3] for both inputs. Only
# get_shape() is checked; the tensor is never evaluated.
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, ids [2]
# example 1, ids [0, 1]
# example 2, ids []
# example 3, ids [1]
'indices': ((0, 0), (1, 0), (1, 1), (3, 0)),
'values': (2, 0, 1, 1),
'dense_shape': (4, 2)},
'expected_shape': [4, 2, 3]},
{'testcase_name': '3D',
'sparse_input_args': {
# example 0, ids [[2]]
# example 1, ids [[0, 1], [2]]
# example 2, ids []
# example 3, ids [[1], [0, 2]]
'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
(3, 0, 0), (3, 1, 0), (3, 1, 1)),
'values': (2, 0, 1, 2, 1, 0, 2),
'dense_shape': (4, 2, 2)},
'expected_shape': [4, 2, 3]}
)
@test_util.run_in_graph_and_eager_modes
def test_static_shape_from_tensors_indicator(
self, sparse_input_args, expected_shape):
"""Tests that we return a known static shape when we have one."""
sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args)
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=3)
indicator_column = fc.indicator_column(categorical_column)
sequence_input_layer = sfc.SequenceFeatures([indicator_column])
input_layer, _ = sequence_input_layer({'aaa': sparse_input})
shape = input_layer.get_shape()
self.assertEqual(shape, expected_shape)
# Combines a 2-element and a 1-element sequence numeric column, checks that
# compute_output_shape((None, None)) reports a last dimension of 3, then
# checks the evaluated output and sequence lengths.
@test_util.run_in_graph_and_eager_modes
def test_compute_output_shape(self):
price1 = sfc.sequence_numeric_column('price1', shape=2)
price2 = sfc.sequence_numeric_column('price2')
features = {
'price1': sparse_tensor.SparseTensor(
indices=[[0, 0, 0], [0, 0, 1],
[0, 1, 0], [0, 1, 1],
[1, 0, 0], [1, 0, 1],
[2, 0, 0], [2, 0, 1],
[3, 0, 0], [3, 0, 1]],
values=[0., 1., 10., 11., 100., 101., 200., 201., 300., 301.],
dense_shape=(4, 3, 2)),
'price2': sparse_tensor.SparseTensor(
indices=[[0, 0],
[0, 1],
[1, 0],
[2, 0],
[3, 0]],
values=[10., 11., 20., 30., 40.],
dense_shape=(4, 3))}
sequence_features = sfc.SequenceFeatures([price1, price2])
seq_input, seq_len = sequence_features(features)
self.assertEqual(
sequence_features.compute_output_shape((None, None)),
(None, None, 3))
self.evaluate(variables_lib.global_variables_initializer())
self.evaluate(lookup_ops.tables_initializer())
# Each expected row is price1's two values followed by price2's one value;
# steps with no input are zero-padded.
self.assertAllClose([[[0., 1., 10.], [10., 11., 11.], [0., 0., 0.]],
[[100., 101., 20.], [0., 0., 0.], [0., 0., 0.]],
[[200., 201., 30.], [0., 0., 0.], [0., 0., 0.]],
[[300., 301., 40.], [0., 0., 0.], [0., 0., 0.]]],
self.evaluate(seq_input))
self.assertAllClose([2, 1, 1, 1], self.evaluate(seq_len))
@test_util.run_all_in_graph_and_eager_modes
class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase):
"""Tests the utility fn concatenate_context_input."""

View File

@ -22,7 +22,6 @@ from absl.testing import parameterized
from tensorflow.python.feature_column import dense_features
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.feature_column import serialization
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@ -180,40 +179,6 @@ class DenseFeaturesSerializationTest(test.TestCase, parameterized.TestCase):
self.assertEqual(new_layer._feature_columns[0].name, 'a_X_b_indicator')
# Checks get_config / from_config round-tripping of SequenceFeatures for the
# three (trainable, name) combinations listed in the parameter tuples.
@test_util.run_all_in_graph_and_eager_modes
class SequenceFeaturesSerializationTest(test.TestCase, parameterized.TestCase):
# get_config must carry the layer name, the trainable flag and one serialized
# SequenceNumericColumn with shape (1,).
@parameterized.named_parameters(('default', None, None),
('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_get_config(self, trainable, name):
cols = [sfc.sequence_numeric_column('a')]
orig_layer = sfc.SequenceFeatures(cols, trainable=trainable, name=name)
config = orig_layer.get_config()
self.assertEqual(config['name'], orig_layer.name)
self.assertEqual(config['trainable'], trainable)
self.assertLen(config['feature_columns'], 1)
self.assertEqual(config['feature_columns'][0]['class_name'],
'SequenceNumericColumn')
self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,))
# A layer rebuilt with from_config must keep the name, the trainable flag and
# the single feature column named 'a'.
@parameterized.named_parameters(('default', None, None),
('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_from_config(self, trainable, name):
cols = [sfc.sequence_numeric_column('a')]
orig_layer = sfc.SequenceFeatures(cols, trainable=trainable, name=name)
config = orig_layer.get_config()
new_layer = sfc.SequenceFeatures.from_config(config)
self.assertEqual(new_layer.name, orig_layer.name)
self.assertEqual(new_layer.trainable, trainable)
self.assertLen(new_layer._feature_columns, 1)
self.assertEqual(new_layer._feature_columns[0].name, 'a')
@test_util.run_all_in_graph_and_eager_modes
class LinearModelLayerSerializationTest(test.TestCase, parameterized.TestCase):

View File

@ -28,6 +28,7 @@ py_library(
"//tensorflow/python/eager:monitoring",
"//tensorflow/python/keras/applications",
"//tensorflow/python/keras/datasets",
"//tensorflow/python/keras/feature_column",
"//tensorflow/python/keras/layers",
"//tensorflow/python/keras/mixed_precision/experimental:mixed_precision_experimental",
"//tensorflow/python/keras/optimizer_v2",

View File

@ -47,6 +47,7 @@ keras_packages = [
"tensorflow.python.keras.engine.sequential",
"tensorflow.python.keras.engine.training",
"tensorflow.python.keras.estimator",
"tensorflow.python.keras.feature_column.sequence_feature_column",
"tensorflow.python.keras.initializers",
"tensorflow.python.keras.initializers.initializers_v1",
"tensorflow.python.keras.initializers.initializers_v2",

View File

@ -0,0 +1,74 @@
# Build rules for the Keras feature-column package.
load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")

# By default, targets in this package are visible only to the feature_column
# and keras package trees.
package(
    default_visibility = [
        "//tensorflow/python/feature_column:__subpackages__",
        "//tensorflow/python/keras:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])
# Umbrella target for this package; it has no sources of its own and only
# forwards to :sequence_feature_column.
py_library(
    name = "feature_column",
    deps = [
        ":sequence_feature_column",
    ],
)
# Library target for sequence_feature_column.py.
py_library(
    name = "sequence_feature_column",
    srcs = ["sequence_feature_column.py"],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tf_export",
        "//tensorflow/python/feature_column:feature_column_v2",
        "//tensorflow/python/keras:backend",
    ],
)
# Unit tests for :sequence_feature_column.
tf_py_test(
    name = "sequence_feature_column_test",
    srcs = ["sequence_feature_column_test.py"],
    deps = [
        ":sequence_feature_column",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column:feature_column_v2",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "@absl_py//absl/testing:parameterized",
    ],
)
# Integration test for :sequence_feature_column. Unlike the unit test above,
# it is declared with plain py_test and carries the "no_pip" tag.
py_test(
    name = "sequence_feature_column_integration_test",
    srcs = ["sequence_feature_column_integration_test.py"],
    python_version = "PY3",
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":sequence_feature_column",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/feature_column:feature_column_v2",
        "//tensorflow/python/keras/layers:recurrent",
    ],
)

View File

@ -0,0 +1,173 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn for sequential input.
NOTE: This API is a work in progress and will likely be changing frequently.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.util.tf_export import keras_export
# pylint: disable=protected-access
@keras_export('keras.experimental.SequenceFeatures')
class SequenceFeatures(fc._BaseFeaturesLayer):
  """A layer for sequence input.

  All `feature_columns` must be sequence dense columns with the same
  `sequence_length`. The output of this method can be fed into sequence
  networks, such as RNN.

  The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`.
  `T` is the maximum sequence length for this batch, which could differ from
  batch to batch.

  If multiple `feature_columns` are given with `Di` `num_elements` each, their
  outputs are concatenated. So, the final `Tensor` has shape
  `[batch_size, T, D0 + D1 + ... + Dn]`.

  Example:

  ```python
  # Behavior of some cells or feature columns may depend on whether we are in
  # training or inference mode, e.g. applying dropout.
  training = True
  rating = sequence_numeric_column('rating')
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [rating, watches_embedding]

  sequence_input_layer = SequenceFeatures(columns)
  features = tf.io.parse_example(...,
                                 features=make_parse_example_spec(columns))
  sequence_input, sequence_length = sequence_input_layer(
      features, training=training)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  # `training` is a call-time argument of Keras layers, and the final state is
  # only returned when `return_state=True`.
  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(
      sequence_input, mask=sequence_length_mask, training=training)
  ```
  """

  def __init__(
      self,
      feature_columns,
      trainable=True,
      name=None,
      **kwargs):
    """Constructs a SequenceFeatures layer.

    Args:
      feature_columns: An iterable of dense sequence columns. Valid columns are
        - `embedding_column` that wraps a `sequence_categorical_column_with_*`
        - `sequence_numeric_column`.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the SequenceFeatures.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: If any of the `feature_columns` is not a
        `SequenceDenseColumn`.
    """
    super(SequenceFeatures, self).__init__(
        feature_columns=feature_columns,
        trainable=trainable,
        name=name,
        expected_column_type=fc.SequenceDenseColumn,
        **kwargs)

  @property
  def _is_feature_layer(self):
    """Marks this layer as a feature layer; always `True`."""
    return True

  def _target_shape(self, input_shape, total_elements):
    """Returns `(input_shape[0], input_shape[1], total_elements)`."""
    return (input_shape[0], input_shape[1], total_elements)

  def call(self, features, training=None):
    """Returns sequence input corresponding to the `feature_columns`.

    Args:
      features: A dict mapping keys to tensors.
      training: Python boolean or None, indicating whether the layer is being
        run in training mode. This argument is passed to the call method of any
        `FeatureColumn` that takes a `training` argument. For example, if a
        `FeatureColumn` performed dropout, the column could expose a `training`
        argument to control whether the dropout should be applied. If `None`,
        defaults to `tf.keras.backend.learning_phase()`.

    Returns:
      An `(input_layer, sequence_length)` tuple where:
      - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
          `T` is the maximum sequence length for this batch, which could differ
          from batch to batch. `D` is the sum of `num_elements` for all
          `feature_columns`.
      - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
          length for each example.

    Raises:
      ValueError: If features are not a dictionary.
    """
    if not isinstance(features, dict):
      # Format the offending value into the message itself. Passing it as a
      # second exception argument makes the error render as a tuple repr.
      raise ValueError(
          'We expected a dictionary here. Instead we got: {}'.format(features))
    if training is None:
      training = backend.learning_phase()
    transformation_cache = fc.FeatureTransformationCache(features)
    output_tensors = []
    sequence_lengths = []

    for column in self._feature_columns:
      with ops.name_scope(column.name):
        try:
          dense_tensor, sequence_length = column.get_sequence_dense_tensor(
              transformation_cache, self._state_manager, training=training)
        except TypeError:
          # Retry without `training`, presumably for columns whose
          # `get_sequence_dense_tensor` does not accept that keyword.
          # NOTE(review): a TypeError raised inside the column for any other
          # reason also lands here and triggers a second call.
          dense_tensor, sequence_length = column.get_sequence_dense_tensor(
              transformation_cache, self._state_manager)
        # Flattens the final dimension to produce a 3D Tensor.
        output_tensors.append(self._process_dense_tensor(column, dense_tensor))
        sequence_lengths.append(sequence_length)

    # Check and process sequence lengths.
    fc._verify_static_batch_size_equality(sequence_lengths,
                                          self._feature_columns)
    sequence_length = _assert_all_equal_and_return(sequence_lengths)

    return self._verify_and_concat_tensors(output_tensors), sequence_length
def _assert_all_equal_and_return(tensors, name=None):
  """Returns the first tensor, guarded by equality checks against the rest.

  Args:
    tensors: Sequence of tensors that are expected to be equal.
    name: Optional name for the enclosing name scope; 'assert_all_equal' is
      the default scope name.

  Returns:
    `tensors[0]` unchanged when it is the only element; otherwise an identity
    of `tensors[0]` that has a control dependency on one `assert_equal` op per
    remaining tensor.
  """
  with ops.name_scope(name, 'assert_all_equal', values=tensors):
    # A single tensor has nothing to be compared against.
    if len(tensors) == 1:
      return tensors[0]
    reference = tensors[0]
    equality_checks = [
        check_ops.assert_equal(reference, other) for other in tensors[1:]
    ]
    with ops.control_dependencies(equality_checks):
      return array_ops.identity(reference)

View File

@ -0,0 +1,259 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration test for sequence feature columns with SequenceExamples."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from google.protobuf import text_format
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.feature_column import dense_features
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class SequenceFeatureColumnIntegrationTest(test.TestCase):
  """Integration tests that feed parsed features into `SequenceFeatures`."""

  # NOTE(review): no test in this class calls this method; the test below
  # calls the module-level `_make_sequence_example()` function instead. This
  # variant also writes a 'str_list' feature list, while
  # `_build_feature_columns` reads 'bytes_list'. Confirm whether it is dead
  # code.
  def _make_sequence_example(self):
    """Builds a SequenceExample with two context features and two lists.

    The context holds 'int_ctx' = [5] and 'float_ctx' = [123.6]. 'int_list'
    gets one step per even value in 0..8 and 'str_list' one step per odd value
    in 1..9; each step repeats its value `val` times.
    """
    example = example_pb2.SequenceExample()
    example.context.feature['int_ctx'].int64_list.value.extend([5])
    example.context.feature['float_ctx'].float_list.value.extend([123.6])
    for val in range(0, 10, 2):
      feat = feature_pb2.Feature()
      feat.int64_list.value.extend([val] * val)
      example.feature_lists.feature_list['int_list'].feature.extend([feat])
    for val in range(1, 11, 2):
      feat = feature_pb2.Feature()
      feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val)
      example.feature_lists.feature_list['str_list'].feature.extend([feat])

    return example

  def _build_feature_columns(self):
    """Returns `(ctx_cols, seq_cols)`: context and sequence column lists."""
    col = fc.categorical_column_with_identity('int_ctx', num_buckets=100)
    ctx_cols = [
        fc.embedding_column(col, dimension=10),
        fc.numeric_column('float_ctx')
    ]

    identity_col = sfc.sequence_categorical_column_with_identity(
        'int_list', num_buckets=10)
    bucket_col = sfc.sequence_categorical_column_with_hash_bucket(
        'bytes_list', hash_bucket_size=100)
    seq_cols = [
        fc.embedding_column(identity_col, dimension=10),
        fc.embedding_column(bucket_col, dimension=20)
    ]

    return ctx_cols, seq_cols

  def test_sequence_example_into_input_layer(self):
    """Parses serialized SequenceExamples and runs them through an RNN."""
    # 100 copies of the same serialized example, batched by 20 below.
    examples = [_make_sequence_example().SerializeToString()] * 100
    ctx_cols, seq_cols = self._build_feature_columns()

    def _parse_example(example):
      # Merge context and sequence features into a single feature dict.
      ctx, seq = parsing_ops.parse_single_sequence_example(
          example,
          context_features=fc.make_parse_example_spec_v2(ctx_cols),
          sequence_features=fc.make_parse_example_spec_v2(seq_cols))
      ctx.update(seq)
      return ctx

    ds = dataset_ops.Dataset.from_tensor_slices(examples)
    ds = ds.map(_parse_example)
    ds = ds.batch(20)

    # Test on a single batch
    features = dataset_ops.make_one_shot_iterator(ds).get_next()

    # Tile the context features across the sequence features
    sequence_input_layer = ksfc.SequenceFeatures(seq_cols)
    seq_layer, _ = sequence_input_layer(features)
    input_layer = dense_features.DenseFeatures(ctx_cols)
    ctx_layer = input_layer(features)
    # `input_layer` is rebound here from the DenseFeatures layer to the
    # combined context + sequence tensor.
    input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)

    rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
    output = rnn_layer(input_layer)

    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      features_r = sess.run(features)
      self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6])

      output_r = sess.run(output)
      self.assertAllEqual(output_r.shape, [20, 10])

  @test_util.run_deprecated_v1
  def test_shared_sequence_non_sequence_into_input_layer(self):
    """Shares one embedding between a sequence and a non-sequence column."""
    non_seq = fc.categorical_column_with_identity('non_seq',
                                                  num_buckets=10)
    seq = sfc.sequence_categorical_column_with_identity('seq',
                                                        num_buckets=10)
    shared_non_seq, shared_seq = fc.shared_embedding_columns_v2(
        [non_seq, seq],
        dimension=4,
        combiner='sum',
        initializer=init_ops_v2.Ones(),
        shared_embedding_collection_name='shared')

    # `seq` and `non_seq` are rebound from column objects to input tensors;
    # the shared columns created above are what the layers consume.
    seq = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=[0, 1, 2],
        dense_shape=[2, 2])
    non_seq = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=[0, 1, 2],
        dense_shape=[2, 2])

    features = {'seq': seq, 'non_seq': non_seq}

    # Feed the same inputs through a sequence layer and a dense layer that use
    # the two shared embedding columns.
    seq_input, seq_length = ksfc.SequenceFeatures([shared_seq])(features)
    non_seq_input = dense_features.DenseFeatures([shared_non_seq])(features)

    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      output_seq, output_seq_length, output_non_seq = sess.run(
          [seq_input, seq_length, non_seq_input])
      # The embedding is initialized to ones: each present sequence step is a
      # row of ones, and with combiner='sum' the non-sequence rows add up to
      # the number of ids in the example (2 and 1).
      self.assertAllEqual(output_seq, [[[1, 1, 1, 1], [1, 1, 1, 1]],
                                       [[1, 1, 1, 1], [0, 0, 0, 0]]])
      self.assertAllEqual(output_seq_length, [2, 1])
      self.assertAllEqual(output_non_seq, [[2, 2, 2, 2], [1, 1, 1, 1]])
# Text-format SequenceExample parsed by the module-level
# `_make_sequence_example()`. It has two context features ('float_ctx' and
# 'int_ctx') and three feature lists ('bytes_list', 'float_list' and
# 'int_list') with three steps each.
_SEQ_EX_PROTO = """
context {
feature {
key: "float_ctx"
value {
float_list {
value: 123.6
}
}
}
feature {
key: "int_ctx"
value {
int64_list {
value: 5
}
}
}
}
feature_lists {
feature_list {
key: "bytes_list"
value {
feature {
bytes_list {
value: "a"
}
}
feature {
bytes_list {
value: "b"
value: "c"
}
}
feature {
bytes_list {
value: "d"
value: "e"
value: "f"
value: "g"
}
}
}
}
feature_list {
key: "float_list"
value {
feature {
float_list {
value: 1.0
}
}
feature {
float_list {
value: 3.0
value: 3.0
value: 3.0
}
}
feature {
float_list {
value: 5.0
value: 5.0
value: 5.0
value: 5.0
value: 5.0
}
}
}
}
feature_list {
key: "int_list"
value {
feature {
int64_list {
value: 2
value: 2
}
}
feature {
int64_list {
value: 4
value: 4
value: 4
value: 4
}
}
feature {
int64_list {
value: 6
value: 6
value: 6
value: 6
value: 6
value: 6
}
}
}
}
}
"""
def _make_sequence_example():
  """Parses `_SEQ_EX_PROTO` into a fresh `SequenceExample` and returns it."""
  return text_format.Parse(_SEQ_EX_PROTO, example_pb2.SequenceExample())
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
  test.main()

View File

@ -0,0 +1,687 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sequential_feature_column."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc
from tensorflow.python.keras.saving import model_config
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
def _initialized_session(config=None):
  """Creates a session and runs the variable and table initializers in it.

  Args:
    config: Optional session config forwarded to `session.Session`.

  Returns:
    The new session, after the global-variables initializer and then the
    lookup-tables initializer have been run.
  """
  new_session = session.Session(config=config)
  # Each initializer op is created and run before the next one is created.
  for make_init_op in (variables_lib.global_variables_initializer,
                       lookup_ops.tables_initializer):
    new_session.run(make_init_op())
  return new_session
class SequenceFeaturesTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
    {'testcase_name': '2D',
     'sparse_input_args_a': {
         # example 0, ids [2]
         # example 1, ids [0, 1]
         'indices': ((0, 0), (1, 0), (1, 1)),
         'values': (2, 0, 1),
         'dense_shape': (2, 2)},
     'sparse_input_args_b': {
         # example 0, ids [1]
         # example 1, ids [2, 0]
         'indices': ((0, 0), (1, 0), (1, 1)),
         'values': (1, 2, 0),
         'dense_shape': (2, 2)},
     'expected_input_layer': [
         # example 0, ids_a [2], ids_b [1]
         [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
         # example 1, ids_a [0, 1], ids_b [2, 0]
         [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],],
     'expected_sequence_length': [1, 2]},
    {'testcase_name': '3D',
     'sparse_input_args_a': {
         # feature 0, ids [[2], [0, 1]]
         # feature 1, ids [[0, 0], [1]]
         'indices': (
             (0, 0, 0), (0, 1, 0), (0, 1, 1),
             (1, 0, 0), (1, 0, 1), (1, 1, 0)),
         'values': (2, 0, 1, 0, 0, 1),
         'dense_shape': (2, 2, 2)},
     'sparse_input_args_b': {
         # feature 0, ids [[1, 1], [1]]
         # feature 1, ids [[2], [0]]
         'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
         'values': (1, 1, 1, 2, 0),
         'dense_shape': (2, 2, 2)},
     'expected_input_layer': [
         # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
         [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]],
         # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -]
         [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]],
     'expected_sequence_length': [2, 2]},
    )
@test_util.run_in_graph_and_eager_modes
def test_embedding_column(
    self, sparse_input_args_a, sparse_input_args_b, expected_input_layer,
    expected_sequence_length):
  """Checks the concatenated output of two sequence embedding columns.

  Column 'aaa' uses a 2-wide embedding and column 'bbb' a 3-wide one, so each
  expected output step has 5 values: the 'aaa' values followed by the 'bbb'
  values.
  """
  sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
  sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)
  vocabulary_size = 3
  embedding_dimension_a = 2
  embedding_values_a = (
      (1., 2.),  # id 0
      (3., 4.),  # id 1
      (5., 6.)  # id 2
  )
  embedding_dimension_b = 3
  embedding_values_b = (
      (11., 12., 13.),  # id 0
      (14., 15., 16.),  # id 1
      (17., 18., 19.)  # id 2
  )

  def _get_initializer(embedding_dimension, embedding_values):
    # Returns an initializer that checks the arguments it is called with and
    # then yields the fixed embedding table.

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values
    return _initializer

  categorical_column_a = sfc.sequence_categorical_column_with_identity(
      key='aaa', num_buckets=vocabulary_size)
  embedding_column_a = fc.embedding_column(
      categorical_column_a,
      dimension=embedding_dimension_a,
      initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
  categorical_column_b = sfc.sequence_categorical_column_with_identity(
      key='bbb', num_buckets=vocabulary_size)
  embedding_column_b = fc.embedding_column(
      categorical_column_b,
      dimension=embedding_dimension_b,
      initializer=_get_initializer(embedding_dimension_b, embedding_values_b))

  # Test that columns are reordered alphabetically.
  sequence_input_layer = ksfc.SequenceFeatures(
      [embedding_column_b, embedding_column_a])
  input_layer, sequence_length = sequence_input_layer({
      'aaa': sparse_input_a, 'bbb': sparse_input_b,})

  self.evaluate(variables_lib.global_variables_initializer())
  weights = sequence_input_layer.weights
  self.assertCountEqual(
      ('sequence_features/aaa_embedding/embedding_weights:0',
       'sequence_features/bbb_embedding/embedding_weights:0'),
      tuple([v.name for v in weights]))
  # NOTE(review): the name check above ignores order, but the two value checks
  # below assume `weights` lists the 'aaa' table before the 'bbb' table.
  self.assertAllEqual(embedding_values_a, self.evaluate(weights[0]))
  self.assertAllEqual(embedding_values_b, self.evaluate(weights[1]))
  self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
  self.assertAllEqual(
      expected_sequence_length, self.evaluate(sequence_length))
@test_util.run_in_graph_and_eager_modes
def test_embedding_column_with_non_sequence_categorical(self):
  """Calling the layer with a non-sequence embedding column must fail."""
  num_ids = 3
  # Row 0 holds id [2]; row 1 holds ids [0, 1].
  ids = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0), (1, 1)),
      values=(2, 0, 1),
      dense_shape=(2, 2))

  plain_categorical = fc.categorical_column_with_identity(
      key='aaa', num_buckets=num_ids)
  layer = ksfc.SequenceFeatures(
      [fc.embedding_column(plain_categorical, dimension=2)])

  expected_error = (
      r'In embedding_column: aaa_embedding\. categorical_column must be of '
      r'type SequenceCategoricalColumn to use SequenceFeatures\.')
  # Only the call is inside the context; construction above must succeed.
  with self.assertRaisesRegexp(ValueError, expected_error):
    _, _ = layer({'aaa': ids})
@test_util.run_in_graph_and_eager_modes
def test_shared_embedding_column(self):
  """Checks two sequence columns that share a single embedding table."""
  # The whole test runs inside a fresh graph, so the global-variables
  # collection read below only contains variables created here.
  with ops.Graph().as_default():
    vocabulary_size = 3
    sparse_input_a = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(2, 2))
    sparse_input_b = sparse_tensor.SparseTensorValue(
        # example 0, ids [1]
        # example 1, ids [2, 0]
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, 2, 0),
        dense_shape=(2, 2))

    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 4.),  # id 1
        (5., 6.)  # id 2
    )

    def _get_initializer(embedding_dimension, embedding_values):
      # Returns an initializer that checks the arguments it is called with
      # and then yields the fixed embedding table.

      def _initializer(shape, dtype, partition_info=None):
        self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
        self.assertEqual(dtypes.float32, dtype)
        self.assertIsNone(partition_info)
        return embedding_values

      return _initializer

    expected_input_layer = [
        # example 0, ids_a [2], ids_b [1]
        [[5., 6., 3., 4.], [0., 0., 0., 0.]],
        # example 1, ids_a [0, 1], ids_b [2, 0]
        [[1., 2., 5., 6.], [3., 4., 1., 2.]],
    ]
    expected_sequence_length = [1, 2]

    categorical_column_a = sfc.sequence_categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = sfc.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    # Test that columns are reordered alphabetically.
    shared_embedding_columns = fc.shared_embedding_columns_v2(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension,
        initializer=_get_initializer(embedding_dimension, embedding_values))

    sequence_input_layer = ksfc.SequenceFeatures(shared_embedding_columns)
    input_layer, sequence_length = sequence_input_layer({
        'aaa': sparse_input_a, 'bbb': sparse_input_b})

    # Both columns must be backed by exactly one shared variable.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertCountEqual(
        ('aaa_bbb_shared_embedding:0',),
        tuple([v.name for v in global_vars]))
    with _initialized_session() as sess:
      self.assertAllEqual(embedding_values,
                          global_vars[0].eval(session=sess))
      self.assertAllEqual(expected_input_layer,
                          input_layer.eval(session=sess))
      self.assertAllEqual(
          expected_sequence_length, sequence_length.eval(session=sess))
@test_util.run_deprecated_v1
def test_shared_embedding_column_with_non_sequence_categorical(self):
  """Calling the layer with non-sequence shared embeddings must fail."""
  num_ids = 3
  # Both inputs have the same layout: row 0 holds id [2], row 1 ids [0, 1].
  ids_a = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0), (1, 1)),
      values=(2, 0, 1),
      dense_shape=(2, 2))
  ids_b = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0), (1, 1)),
      values=(2, 0, 1),
      dense_shape=(2, 2))

  plain_categorical_a = fc.categorical_column_with_identity(
      key='aaa', num_buckets=num_ids)
  plain_categorical_b = fc.categorical_column_with_identity(
      key='bbb', num_buckets=num_ids)
  shared_columns = fc.shared_embedding_columns_v2(
      [plain_categorical_a, plain_categorical_b], dimension=2)

  layer = ksfc.SequenceFeatures(shared_columns)
  expected_error = (
      r'In embedding_column: aaa_shared_embedding\. categorical_column must '
      r'be of type SequenceCategoricalColumn to use SequenceFeatures\.')
  with self.assertRaisesRegexp(ValueError, expected_error):
    _, _ = layer({'aaa': ids_a,
                  'bbb': ids_b})
@parameterized.named_parameters(
    {'testcase_name': '2D',
     'sparse_input_args_a': {
         # example 0, ids [2]
         # example 1, ids [0, 1]
         'indices': ((0, 0), (1, 0), (1, 1)),
         'values': (2, 0, 1),
         'dense_shape': (2, 2)},
     'sparse_input_args_b': {
         # example 0, ids [1]
         # example 1, ids [1, 0]
         'indices': ((0, 0), (1, 0), (1, 1)),
         'values': (1, 1, 0),
         'dense_shape': (2, 2)},
     'expected_input_layer': [
         # example 0, ids_a [2], ids_b [1]
         [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]],
         # example 1, ids_a [0, 1], ids_b [1, 0]
         [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
     'expected_sequence_length': [1, 2]},
    {'testcase_name': '3D',
     'sparse_input_args_a': {
         # feature 0, ids [[2], [0, 1]]
         # feature 1, ids [[0, 0], [1]]
         'indices': (
             (0, 0, 0), (0, 1, 0), (0, 1, 1),
             (1, 0, 0), (1, 0, 1), (1, 1, 0)),
         'values': (2, 0, 1, 0, 0, 1),
         'dense_shape': (2, 2, 2)},
     'sparse_input_args_b': {
         # feature 0, ids [[1, 1], [1]]
         # feature 1, ids [[1], [0]]
         'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
         'values': (1, 1, 1, 1, 0),
         'dense_shape': (2, 2, 2)},
     'expected_input_layer': [
         # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
         [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]],
         # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -]
         [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
     'expected_sequence_length': [2, 2]},
    )
@test_util.run_in_graph_and_eager_modes
def test_indicator_column(
    self, sparse_input_args_a, sparse_input_args_b, expected_input_layer,
    expected_sequence_length):
  """Checks the output of two sequence indicator columns."""
  ids_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
  ids_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)

  # 'aaa' has 3 buckets and 'bbb' has 2, giving 5 values per expected step.
  num_buckets_a = 3
  num_buckets_b = 2
  indicator_a = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(
          key='aaa', num_buckets=num_buckets_a))
  indicator_b = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(
          key='bbb', num_buckets=num_buckets_b))

  # The columns are passed as [b, a], yet the expected values list the 'aaa'
  # slots first.
  layer = ksfc.SequenceFeatures([indicator_b, indicator_a])
  dense_output, lengths = layer({'aaa': ids_a, 'bbb': ids_b})

  self.assertAllEqual(expected_input_layer, self.evaluate(dense_output))
  self.assertAllEqual(expected_sequence_length, self.evaluate(lengths))
@test_util.run_in_graph_and_eager_modes
def test_indicator_column_with_non_sequence_categorical(self):
  """Calling the layer with a non-sequence indicator column must fail."""
  num_ids = 3
  # Row 0 holds id [2]; row 1 holds ids [0, 1].
  ids = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0), (1, 1)),
      values=(2, 0, 1),
      dense_shape=(2, 2))

  plain_categorical = fc.categorical_column_with_identity(
      key='aaa', num_buckets=num_ids)
  layer = ksfc.SequenceFeatures([fc.indicator_column(plain_categorical)])

  expected_error = (
      r'In indicator_column: aaa_indicator\. categorical_column must be of '
      r'type SequenceCategoricalColumn to use SequenceFeatures\.')
  with self.assertRaisesRegexp(ValueError, expected_error):
    _, _ = layer({'aaa': ids})
@parameterized.named_parameters(
    {'testcase_name': '2D',
     'sparse_input_args': {
         # example 0, values [0., 1]
         # example 1, [10.]
         'indices': ((0, 0), (0, 1), (1, 0)),
         'values': (0., 1., 10.),
         'dense_shape': (2, 2)},
     'expected_input_layer': [
         [[0.], [1.]],
         [[10.], [0.]]],
     'expected_sequence_length': [2, 1]},
    {'testcase_name': '3D',
     'sparse_input_args': {
         # feature 0, ids [[20, 3], [5]]
         # feature 1, ids [[3], [8]]
         'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
         'values': (20., 3., 5., 3., 8.),
         'dense_shape': (2, 2, 2)},
     'expected_input_layer': [
         [[20.], [3.], [5.], [0.]],
         [[3.], [0.], [8.], [0.]]],
     'expected_sequence_length': [2, 2]},
    )
@test_util.run_in_graph_and_eager_modes
def test_numeric_column(
    self, sparse_input_args, expected_input_layer, expected_sequence_length):
  """Checks the dense output and lengths for one sequence numeric column."""
  values = sparse_tensor.SparseTensorValue(**sparse_input_args)
  layer = ksfc.SequenceFeatures([sfc.sequence_numeric_column('aaa')])

  dense_output, lengths = layer({'aaa': values})

  self.assertAllEqual(expected_input_layer, self.evaluate(dense_output))
  self.assertAllEqual(expected_sequence_length, self.evaluate(lengths))
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, values [0., 1., 2., 3., 4., 5., 6., 7.]
# example 1, [10., 11., 12., 13.]
'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
(0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 8)},
'expected_input_layer': [
# The output of numeric_column._get_dense_tensor should be flattened.
[[0., 1., 2., 3.], [4., 5., 6., 7.]],
[[10., 11., 12., 13.], [0., 0., 0., 0.]]],
'expected_sequence_length': [2, 1]},
{'testcase_name': '3D',
'sparse_input_args': {
# example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
# example 1, [[10., 11., 12., 13.], []]
'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3),
(1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 2, 4)},
'expected_input_layer': [
# The output of numeric_column._get_dense_tensor should be flattened.
[[0., 1., 2., 3.], [4., 5., 6., 7.]],
[[10., 11., 12., 13.], [0., 0., 0., 0.]]],
'expected_sequence_length': [2, 1]},
)
@test_util.run_in_graph_and_eager_modes
def test_numeric_column_multi_dim(
    self, sparse_input_args, expected_input_layer, expected_sequence_length):
  """Checks `SequenceFeatures` output for a `shape=(2, 2)` numeric column.

  Args:
    sparse_input_args: Keyword arguments for the `SparseTensorValue` that is
      fed to the layer under the key 'aaa'.
    expected_input_layer: Expected dense output of the layer.
    expected_sequence_length: Expected sequence lengths returned by the layer.
  """
  features = {'aaa': sparse_tensor.SparseTensorValue(**sparse_input_args)}
  column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
  layer = ksfc.SequenceFeatures([column])
  dense_output, lengths = layer(features)

  self.assertAllEqual(expected_input_layer, self.evaluate(dense_output))
  self.assertAllEqual(expected_sequence_length, self.evaluate(lengths))
@test_util.run_in_graph_and_eager_modes
def test_sequence_length_not_equal(self):
  """Tests that an error is raised when sequence lengths are not equal.

  Two sequence numeric columns are fed sparse inputs whose per-example
  sequence lengths differ ([2, 1] vs. [1, 1]); producing the layer's
  sequence length is expected to raise `InvalidArgumentError`.
  """
  # Input a with sequence_length = [2, 1]
  sparse_input_a = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (0, 1), (1, 0)),
      values=(0., 1., 10.),
      dense_shape=(2, 2))
  # Input b with sequence_length = [1, 1]
  sparse_input_b = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0)),
      values=(1., 10.),
      dense_shape=(2, 2))
  numeric_column_a = sfc.sequence_numeric_column('aaa')
  numeric_column_b = sfc.sequence_numeric_column('bbb')
  sequence_input_layer = ksfc.SequenceFeatures(
      [numeric_column_a, numeric_column_b])

  # `assertRaisesRegexp` is a deprecated unittest alias that was removed in
  # Python 3.12; `assertRaisesRegex` is the supported name.
  with self.assertRaisesRegex(
      errors.InvalidArgumentError, r'Condition x == y did not hold.*'):
    # Both the layer call and the evaluation are kept inside the context.
    # NOTE(review): presumably the error can surface at either point
    # depending on graph vs. eager execution -- confirm before narrowing.
    _, sequence_length = sequence_input_layer({
        'aaa': sparse_input_a,
        'bbb': sparse_input_b
    })
    self.evaluate(sequence_length)
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
# example 1, [[[10., 11.], [12., 13.]]]
'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
(0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 8)},
'expected_shape': [2, 2, 4]},
{'testcase_name': '3D',
'sparse_input_args': {
# example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
# example 1, [[10., 11., 12., 13.], []]
'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3),
(1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
'dense_shape': (2, 2, 4)},
'expected_shape': [2, 2, 4]},
)
@test_util.run_in_graph_and_eager_modes
def test_static_shape_from_tensors_numeric(
    self, sparse_input_args, expected_shape):
  """Verifies the static shape of the output for a numeric sequence column.

  Args:
    sparse_input_args: Keyword arguments for the `SparseTensorValue` that is
      fed to the layer under the key 'aaa'.
    expected_shape: Shape that the layer's dense output must report via
      `get_shape()`.
  """
  features = {'aaa': sparse_tensor.SparseTensorValue(**sparse_input_args)}
  column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
  dense_output, _ = ksfc.SequenceFeatures([column])(features)
  self.assertEqual(dense_output.get_shape(), expected_shape)
@parameterized.named_parameters(
{'testcase_name': '2D',
'sparse_input_args': {
# example 0, ids [2]
# example 1, ids [0, 1]
# example 2, ids []
# example 3, ids [1]
'indices': ((0, 0), (1, 0), (1, 1), (3, 0)),
'values': (2, 0, 1, 1),
'dense_shape': (4, 2)},
'expected_shape': [4, 2, 3]},
{'testcase_name': '3D',
'sparse_input_args': {
# example 0, ids [[2]]
# example 1, ids [[0, 1], [2]]
# example 2, ids []
# example 3, ids [[1], [0, 2]]
'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
(3, 0, 0), (3, 1, 0), (3, 1, 1)),
'values': (2, 0, 1, 2, 1, 0, 2),
'dense_shape': (4, 2, 2)},
'expected_shape': [4, 2, 3]}
)
@test_util.run_in_graph_and_eager_modes
def test_static_shape_from_tensors_indicator(
    self, sparse_input_args, expected_shape):
  """Verifies the static shape of the output for an indicator column.

  The indicator column wraps a sequence identity categorical column with
  three buckets keyed on 'aaa'.

  Args:
    sparse_input_args: Keyword arguments for the `SparseTensorValue` that is
      fed to the layer under the key 'aaa'.
    expected_shape: Shape that the layer's dense output must report via
      `get_shape()`.
  """
  features = {'aaa': sparse_tensor.SparseTensorValue(**sparse_input_args)}
  identity_column = sfc.sequence_categorical_column_with_identity(
      key='aaa', num_buckets=3)
  layer = ksfc.SequenceFeatures([fc.indicator_column(identity_column)])
  dense_output, _ = layer(features)
  self.assertEqual(dense_output.get_shape(), expected_shape)
@test_util.run_in_graph_and_eager_modes
def test_compute_output_shape(self):
  """Checks `compute_output_shape` and the output for two numeric columns.

  'price1' is declared with `shape=2` and 'price2' with the default shape;
  the layer is expected to report an output shape of `(None, None, 3)` and
  to yield sequence lengths `[2, 1, 1, 1]` for the inputs below.
  """
  price1 = sfc.sequence_numeric_column('price1', shape=2)
  price2 = sfc.sequence_numeric_column('price2')

  price1_input = sparse_tensor.SparseTensor(
      indices=[[0, 0, 0], [0, 0, 1],
               [0, 1, 0], [0, 1, 1],
               [1, 0, 0], [1, 0, 1],
               [2, 0, 0], [2, 0, 1],
               [3, 0, 0], [3, 0, 1]],
      values=[0., 1., 10., 11., 100., 101., 200., 201., 300., 301.],
      dense_shape=(4, 3, 2))
  price2_input = sparse_tensor.SparseTensor(
      indices=[[0, 0],
               [0, 1],
               [1, 0],
               [2, 0],
               [3, 0]],
      values=[10., 11., 20., 30., 40.],
      dense_shape=(4, 3))

  layer = ksfc.SequenceFeatures([price1, price2])
  dense_output, lengths = layer(
      {'price1': price1_input, 'price2': price2_input})

  self.assertEqual(
      layer.compute_output_shape((None, None)), (None, None, 3))

  self.evaluate(variables_lib.global_variables_initializer())
  self.evaluate(lookup_ops.tables_initializer())

  expected_output = [
      [[0., 1., 10.], [10., 11., 11.], [0., 0., 0.]],
      [[100., 101., 20.], [0., 0., 0.], [0., 0., 0.]],
      [[200., 201., 30.], [0., 0., 0.], [0., 0., 0.]],
      [[300., 301., 40.], [0., 0., 0.], [0., 0., 0.]],
  ]
  self.assertAllClose(expected_output, self.evaluate(dense_output))
  self.assertAllClose([2, 1, 1, 1], self.evaluate(lengths))
@test_util.run_all_in_graph_and_eager_modes
class SequenceFeaturesSerializationTest(test.TestCase, parameterized.TestCase):
  """Config and serialization round-trip tests for `SequenceFeatures`."""

  @parameterized.named_parameters(('default', None, None),
                                  ('trainable', True, 'trainable'),
                                  ('not_trainable', False, 'frozen'))
  def test_get_config(self, trainable, name):
    """Checks the entries produced by `SequenceFeatures.get_config`.

    Args:
      trainable: Value passed as the layer's `trainable` argument.
      name: Value passed as the layer's `name` argument.
    """
    layer = ksfc.SequenceFeatures(
        [sfc.sequence_numeric_column('a')], trainable=trainable, name=name)
    config = layer.get_config()

    self.assertEqual(config['name'], layer.name)
    self.assertEqual(config['trainable'], trainable)

    column_configs = config['feature_columns']
    self.assertLen(column_configs, 1)
    self.assertEqual(column_configs[0]['class_name'], 'SequenceNumericColumn')
    self.assertEqual(column_configs[0]['config']['shape'], (1,))

  @parameterized.named_parameters(('default', None, None),
                                  ('trainable', True, 'trainable'),
                                  ('not_trainable', False, 'frozen'))
  def test_from_config(self, trainable, name):
    """Checks that `from_config` rebuilds a layer from `get_config` output.

    Args:
      trainable: Value passed as the layer's `trainable` argument.
      name: Value passed as the layer's `name` argument.
    """
    original = ksfc.SequenceFeatures(
        [sfc.sequence_numeric_column('a')], trainable=trainable, name=name)
    rebuilt = ksfc.SequenceFeatures.from_config(original.get_config())

    self.assertEqual(rebuilt.name, original.name)
    self.assertEqual(rebuilt.trainable, trainable)

    rebuilt_columns = rebuilt._feature_columns
    self.assertLen(rebuilt_columns, 1)
    self.assertEqual(rebuilt_columns[0].name, 'a')

  def test_serialization_sequence_features(self):
    """Round-trips a layer through `keras.layers.serialize`/`deserialize`."""
    layer = ksfc.SequenceFeatures([sfc.sequence_numeric_column('rating')])
    serialized = keras.layers.serialize(layer)
    self.assertIsInstance(
        keras.layers.deserialize(serialized), ksfc.SequenceFeatures)
class SequenceFeaturesSavingTest(test.TestCase, parameterized.TestCase):
  """Tests that a model containing `SequenceFeatures` survives a JSON reload."""

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_with_sequence_features(self):
    """Rebuilds a `SequenceFeatures` model from JSON and runs `predict` on it.

    A functional model is built from a numeric sequence column 'a' and an
    indicator column over vocabulary column 'b', converted with `to_json`,
    reloaded with `model_from_json`, and the reloaded model is expected to
    return `batch_size` predictions for sparse inputs.
    """
    cols = [
        sfc.sequence_numeric_column('a'),
        fc.indicator_column(
            sfc.sequence_categorical_column_with_vocabulary_list(
                'b', ['one', 'two']))
    ]
    input_layers = {
        'a':
            keras.layers.Input(shape=(None, 1), sparse=True, name='a'),
        'b':
            keras.layers.Input(
                shape=(None, 1), sparse=True, name='b', dtype='string')
    }

    fc_layer, _ = ksfc.SequenceFeatures(cols)(input_layers)
    # TODO(tibell): Figure out the right dtype and apply masking.
    # sequence_length_mask = array_ops.sequence_mask(sequence_length)
    # x = keras.layers.GRU(32)(fc_layer, mask=sequence_length_mask)
    x = keras.layers.GRU(32)(fc_layer)
    output = keras.layers.Dense(10)(x)

    model = keras.models.Model(input_layers, output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='rmsprop',
        metrics=[keras.metrics.categorical_accuracy])

    config = model.to_json()
    loaded_model = model_config.model_from_json(config)

    batch_size = 10
    timesteps = 1

    # One value per example, placed at index (example, 0, 0).
    values_a = np.arange(10, dtype=np.float32)
    indices_a = np.zeros((10, 3), dtype=np.int64)
    indices_a[:, 0] = np.arange(10)
    inputs_a = sparse_tensor.SparseTensor(indices_a, values_a,
                                          (batch_size, timesteps, 1))

    # `np.str` was only an alias for the builtin `str`; it was deprecated in
    # NumPy 1.20 and removed in 1.24, so the builtin is used directly. The
    # resulting array dtype is unchanged.
    values_b = np.zeros(10, dtype=str)
    indices_b = np.zeros((10, 3), dtype=np.int64)
    indices_b[:, 0] = np.arange(10)
    inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
                                          (batch_size, timesteps, 1))

    with self.cached_session():
      # Initialize tables for V1 lookup.
      if not context.executing_eagerly():
        self.evaluate(lookup_ops.tables_initializer())

      self.assertLen(
          loaded_model.predict({
              'a': inputs_a,
              'b': inputs_b
          }, steps=1), batch_size)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()

View File

@ -122,11 +122,13 @@ def populate_deserializable_objects():
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.feature_column.sequence_feature_column import SequenceFeatures # pylint: disable=g-import-not-at-top
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec
LOCAL.ALL_OBJECTS['Network'] = models.Network
LOCAL.ALL_OBJECTS['Model'] = models.Model
LOCAL.ALL_OBJECTS['SequenceFeatures'] = SequenceFeatures
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel
LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel

View File

@ -165,5 +165,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
self.assertIsInstance(new_layer, rnn_v1.GRU)
self.assertNotIsInstance(new_layer, rnn_v2.GRU)
if __name__ == '__main__':
test.main()

View File

@ -1,6 +1,6 @@
path: "tensorflow.keras.experimental.SequenceFeatures"
tf_class {
is_instance: "<class \'tensorflow.python.feature_column.sequence_feature_column.SequenceFeatures\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.sequence_feature_column.SequenceFeatures\'>"
is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"

View File

@ -1,6 +1,6 @@
path: "tensorflow.keras.experimental.SequenceFeatures"
tf_class {
is_instance: "<class \'tensorflow.python.feature_column.sequence_feature_column.SequenceFeatures\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.sequence_feature_column.SequenceFeatures\'>"
is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"