Add support for TPUEstimator to use Feature Column V2 versions of the TPUEmbedding Columns.

PiperOrigin-RevId: 253798577
This commit is contained in:
Bruce Fontaine 2019-06-18 08:56:53 -07:00 committed by TensorFlower Gardener
parent 4a7746d3cc
commit 6ef00b2eb2

View File

@ -531,3 +531,39 @@ class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn,
return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=tensor, sequence_length=tensor_lengths)
def split_sequence_columns_v2(feature_columns):
"""Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.
For use in a TPUEstimator model_fn function. E.g.
def model_fn(features):
sequence_columns, feature_columns = (
tf.tpu.feature_column.split_sequence_columns(feature_columns))
input = tf.feature_column.input_layer(
features=features, feature_columns=feature_columns)
sequence_features, sequence_lengths = (
tf.contrib.feature_column.sequence_input_layer(
features=features, feature_columns=sequence_columns))
Args:
feature_columns: A list of _TPUEmbeddingColumns to split.
Returns:
Two lists of _TPUEmbeddingColumns, the first is the sequence columns and the
second is the non-sequence columns.
"""
sequence_columns = []
non_sequence_columns = []
for column in feature_columns:
if not isinstance(column, (_TPUEmbeddingColumnV2,
_TPUSharedEmbeddingColumnV2)):
raise TypeError(
'column must be a _TPUEmbeddingColumnV2 or '
'_TPUSharedEmbeddingColumnV2 but got %s instead.' % (type(column)))
if column.is_sequence_column():
sequence_columns.append(column)
else:
non_sequence_columns.append(column)
return sequence_columns, non_sequence_columns