diff --git a/tensorflow/python/feature_column/sequence_feature_column_test.py b/tensorflow/python/feature_column/sequence_feature_column_test.py index d0cf5ee7670..e0cd73d17e4 100644 --- a/tensorflow/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/python/feature_column/sequence_feature_column_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.client import session +from tensorflow.python.feature_column import feature_column_lib as fc_lib from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.feature_column import sequence_feature_column as sfc from tensorflow.python.feature_column import serialization @@ -31,7 +32,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -132,7 +132,8 @@ def _get_sequence_dense_tensor(column, features): def _get_sequence_dense_tensor_state(column, features): - state_manager = fc._StateManagerImpl(Layer(), trainable=True) + state_manager = fc._StateManagerImpl( + fc_lib.DenseFeatures(column), trainable=True) column.create_state(state_manager) dense_tensor, lengths = column.get_sequence_dense_tensor( fc.FeatureTransformationCache(features), state_manager)