Remove private concatenate_context_input from sequence column integration test
PiperOrigin-RevId: 353323428 Change-Id: I8637c3db031fffc6ac51e584ec23d153cafceff3
This commit is contained in:
parent
dd8bc0e821
commit
6bb7f19c3d
@ -167,6 +167,7 @@ tf_py_test(
|
||||
deps = [
|
||||
":dense_features",
|
||||
":sequence_feature_column",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:init_ops_v2",
|
||||
@ -177,6 +178,8 @@ tf_py_test(
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/feature_column:feature_column_v2",
|
||||
"//tensorflow/python/keras:metrics", # Import it here since base_layer didn't import it due to circular dependency.
|
||||
"//tensorflow/python/keras/layers:core",
|
||||
"//tensorflow/python/keras/layers:merge",
|
||||
"//tensorflow/python/keras/layers:recurrent",
|
||||
],
|
||||
)
|
||||
|
||||
@ -30,7 +30,10 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.feature_column import dense_features
|
||||
from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers import merge
|
||||
from tensorflow.python.keras.layers import recurrent
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -94,13 +97,14 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase):
|
||||
|
||||
# Tile the context features across the sequence features
|
||||
sequence_input_layer = ksfc.SequenceFeatures(seq_cols)
|
||||
seq_layer, _ = sequence_input_layer(features)
|
||||
input_layer = dense_features.DenseFeatures(ctx_cols)
|
||||
ctx_layer = input_layer(features)
|
||||
input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)
|
||||
seq_input, _ = sequence_input_layer(features)
|
||||
dense_input_layer = dense_features.DenseFeatures(ctx_cols)
|
||||
ctx_input = dense_input_layer(features)
|
||||
ctx_input = core.RepeatVector(array_ops.shape(seq_input)[1])(ctx_input)
|
||||
concatenated_input = merge.concatenate([seq_input, ctx_input])
|
||||
|
||||
rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
|
||||
output = rnn_layer(input_layer)
|
||||
output = rnn_layer(concatenated_input)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user