Add support for TPUEstimator to use Feature Column V2 versions of the TPUEmbedding Columns.
PiperOrigin-RevId: 253798577
This commit is contained in:
parent
4a7746d3cc
commit
6ef00b2eb2
@ -531,3 +531,39 @@ class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn,
|
||||
|
||||
return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
|
||||
dense_tensor=tensor, sequence_length=tensor_lengths)
|
||||
|
||||
|
||||
def split_sequence_columns_v2(feature_columns):
|
||||
"""Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.
|
||||
|
||||
For use in a TPUEstimator model_fn function. E.g.
|
||||
|
||||
def model_fn(features):
|
||||
sequence_columns, feature_columns = (
|
||||
tf.tpu.feature_column.split_sequence_columns(feature_columns))
|
||||
input = tf.feature_column.input_layer(
|
||||
features=features, feature_columns=feature_columns)
|
||||
sequence_features, sequence_lengths = (
|
||||
tf.contrib.feature_column.sequence_input_layer(
|
||||
features=features, feature_columns=sequence_columns))
|
||||
|
||||
Args:
|
||||
feature_columns: A list of _TPUEmbeddingColumns to split.
|
||||
|
||||
Returns:
|
||||
Two lists of _TPUEmbeddingColumns, the first is the sequence columns and the
|
||||
second is the non-sequence columns.
|
||||
"""
|
||||
sequence_columns = []
|
||||
non_sequence_columns = []
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, (_TPUEmbeddingColumnV2,
|
||||
_TPUSharedEmbeddingColumnV2)):
|
||||
raise TypeError(
|
||||
'column must be a _TPUEmbeddingColumnV2 or '
|
||||
'_TPUSharedEmbeddingColumnV2 but got %s instead.' % (type(column)))
|
||||
if column.is_sequence_column():
|
||||
sequence_columns.append(column)
|
||||
else:
|
||||
non_sequence_columns.append(column)
|
||||
return sequence_columns, non_sequence_columns
|
||||
|
Loading…
Reference in New Issue
Block a user