From 6bb7f19c3d2c11c39da951f388f9112e0649ecfc Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Fri, 22 Jan 2021 14:47:23 -0800 Subject: [PATCH] Remove private concatenate_context_input from sequence column integration test PiperOrigin-RevId: 353323428 Change-Id: I8637c3db031fffc6ac51e584ec23d153cafceff3 --- tensorflow/python/keras/feature_column/BUILD | 3 +++ .../sequence_feature_column_integration_test.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD index c55962e5be1..a64f88b639a 100644 --- a/tensorflow/python/keras/feature_column/BUILD +++ b/tensorflow/python/keras/feature_column/BUILD @@ -167,6 +167,7 @@ tf_py_test( deps = [ ":dense_features", ":sequence_feature_column", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops_v2", @@ -177,6 +178,8 @@ tf_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/feature_column:feature_column_v2", "//tensorflow/python/keras:metrics", # Import it here since base_layer didn't import it due to circular dependency. + "//tensorflow/python/keras/layers:core", + "//tensorflow/python/keras/layers:merge", "//tensorflow/python/keras/layers:recurrent", ], ) diff --git a/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py b/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py index b1100bf7b07..1889a71c2d5 100644 --- a/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py +++ b/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py @@ -30,7 +30,10 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.keras.feature_column import dense_features from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import merge from tensorflow.python.keras.layers import recurrent +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops_v2 from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables @@ -94,13 +97,14 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase): # Tile the context features across the sequence features sequence_input_layer = ksfc.SequenceFeatures(seq_cols) - seq_layer, _ = sequence_input_layer(features) - input_layer = dense_features.DenseFeatures(ctx_cols) - ctx_layer = input_layer(features) - input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer) + seq_input, _ = sequence_input_layer(features) + dense_input_layer = dense_features.DenseFeatures(ctx_cols) + ctx_input = dense_input_layer(features) + ctx_input = core.RepeatVector(array_ops.shape(seq_input)[1])(ctx_input) + concatenated_input = merge.concatenate([seq_input, ctx_input]) rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10)) - output = rnn_layer(input_layer) + output = rnn_layer(concatenated_input) with self.cached_session() as sess: sess.run(variables.global_variables_initializer())