Update sequence_feature_column_test to not rely on Keras.

PiperOrigin-RevId: 315282586
Change-Id: I0457ae4072aa672ae6be1bfa176b3b9f3b8fea0d
This commit is contained in:
Scott Zhu 2020-06-08 08:47:38 -07:00 committed by TensorFlower Gardener
parent 3b2109f7de
commit ad6ccc651c

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import sequence_feature_column as sfc from tensorflow.python.feature_column import sequence_feature_column as sfc
from tensorflow.python.feature_column import serialization from tensorflow.python.feature_column import serialization
@ -31,7 +32,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -132,7 +132,8 @@ def _get_sequence_dense_tensor(column, features):
def _get_sequence_dense_tensor_state(column, features): def _get_sequence_dense_tensor_state(column, features):
state_manager = fc._StateManagerImpl(Layer(), trainable=True) state_manager = fc._StateManagerImpl(
fc_lib.DenseFeatures(column), trainable=True)
column.create_state(state_manager) column.create_state(state_manager)
dense_tensor, lengths = column.get_sequence_dense_tensor( dense_tensor, lengths = column.get_sequence_dense_tensor(
fc.FeatureTransformationCache(features), state_manager) fc.FeatureTransformationCache(features), state_manager)