Remove private concatenate_context_input from sequence column integration test

PiperOrigin-RevId: 353323428
Change-Id: I8637c3db031fffc6ac51e584ec23d153cafceff3
This commit is contained in:
Matt Watson 2021-01-22 14:47:23 -08:00 committed by TensorFlower Gardener
parent dd8bc0e821
commit 6bb7f19c3d
2 changed files with 12 additions and 5 deletions

View File

@ -167,6 +167,7 @@ tf_py_test(
deps = [
":dense_features",
":sequence_feature_column",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops_v2",
@ -177,6 +178,8 @@ tf_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/feature_column:feature_column_v2",
"//tensorflow/python/keras:metrics", # Import it here since base_layer didn't import it due to circular dependency.
"//tensorflow/python/keras/layers:core",
"//tensorflow/python/keras/layers:merge",
"//tensorflow/python/keras/layers:recurrent",
],
)

View File

@ -30,7 +30,10 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras.feature_column import dense_features
from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import merge
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
@ -94,13 +97,14 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase):
# Tile the context features across the sequence features
sequence_input_layer = ksfc.SequenceFeatures(seq_cols)
seq_layer, _ = sequence_input_layer(features)
input_layer = dense_features.DenseFeatures(ctx_cols)
ctx_layer = input_layer(features)
input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)
seq_input, _ = sequence_input_layer(features)
dense_input_layer = dense_features.DenseFeatures(ctx_cols)
ctx_input = dense_input_layer(features)
ctx_input = core.RepeatVector(array_ops.shape(seq_input)[1])(ctx_input)
concatenated_input = merge.concatenate([seq_input, ctx_input])
rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
output = rnn_layer(input_layer)
output = rnn_layer(concatenated_input)
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())