Add support to the TPUEmbedding classes for sequence features.
PiperOrigin-RevId: 243712953
parent b67614f20e
commit 85709f9d72
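This change replaces the old feature-to-table-name mapping with a
feature-to-`FeatureConfig` mapping; a `FeatureConfig` with a positive
`max_sequence_length` marks a sequence feature. A minimal sketch of the
resulting setup, mirroring the updated docstring example in the diff below
(the table parameters here are illustrative, not from the commit):

```
table_to_config_dict = {
    'video': tpu_embedding.TableConfig(
        vocabulary_size=4096, dimension=16,
        initializer=None, combiner='mean')}
feature_to_config_dict = {
    # Non-sequence feature: one embedding result per sample.
    'favorited': tpu_embedding.FeatureConfig('video'),
    # Sequence feature: up to 10 embedding results per sample; longer
    # sequences are truncated.
    'watched': tpu_embedding.FeatureConfig('video', max_sequence_length=10)}
```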
@@ -115,7 +115,7 @@ def get_configs_from_feature_columns(feature_columns):
   Returns:
     A tuple of dicts, the first maps tables to their config, the second maps
-    features to tables, and the third maps features to weight key names.
+    features to their config, and the third maps features to weight key names.
   """

   allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn)  # pylint: disable=protected-access
@@ -127,17 +127,18 @@ def get_configs_from_feature_columns(feature_columns):
             type(column), allowed))

   table_to_config = {}
-  feature_to_table = {}
+  feature_to_config = {}
   feature_to_weight_key_name = {}
   for column in feature_columns:
     feature_name = column.get_feature_key_name()
     table_name = _get_table_name_from_embedding_var_name(
         column.get_embedding_var_name())
-    if feature_name in feature_to_table:
+    if feature_name in feature_to_config:
       raise ValueError(
           'Feature column {} is used with multiple embeddings and this is '
           'not supported.'.format(feature_name))
-    feature_to_table[feature_name] = table_name
+    feature_to_config[feature_name] = tpu_embedding.FeatureConfig(
+        table_id=table_name)
     feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
     vocabulary_size, dimension = column.get_embedding_table_size()
     table_to_config[table_name] = tpu_embedding.TableConfig(
@@ -146,7 +147,7 @@ def get_configs_from_feature_columns(feature_columns):
         initializer=column.get_initializer(),
         combiner=column.get_combiner())

-  return table_to_config, feature_to_table, feature_to_weight_key_name
+  return table_to_config, feature_to_config, feature_to_weight_key_name


 class EmbeddingConfigSpec(
@@ -238,7 +239,7 @@ class EmbeddingConfig(object):
     self._num_cores = num_cores
     self._run_config = run_config

-    (self._table_to_config_dict, self._feature_to_table_dict,
+    (self._table_to_config_dict, self._feature_to_config_dict,
      self.feature_to_weight_key_name_dict) = (
          get_configs_from_feature_columns(
              embedding_config_spec.feature_columns))
@@ -286,7 +287,7 @@ class EmbeddingConfig(object):
       cluster_def = None
     tpu_embedding_ = tpu_embedding.TPUEmbedding(
         self._table_to_config_dict,
-        self._feature_to_table_dict,
+        self._feature_to_config_dict,
         batch_size,
         tpu_embedding_mode,
         master,
@@ -310,7 +311,7 @@ def split_inputs(ctx, features, labels):
     tpu_embedding_ = ctx.embedding_config.tpu_embedding
     feature_to_weight_key_name_dict = (
         ctx.embedding_config.feature_to_weight_key_name_dict)
-    for feature_key in tpu_embedding_.feature_to_table_dict:
+    for feature_key in tpu_embedding_.feature_to_config_dict:
       sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
       weight_key_name = feature_to_weight_key_name_dict[feature_key]
       if isinstance(sparse_feature, sparse_tensor.SparseTensor):
@@ -96,6 +96,37 @@ class TableConfig(
         initializer, combiner)


+class FeatureConfig(
+    collections.namedtuple(
+        'FeatureConfig',
+        ['table_id', 'max_sequence_length'])):
+  """Feature configuration."""
+
+  def __new__(cls,
+              table_id,
+              max_sequence_length=0):
+    """Feature configuration.
+
+    Args:
+      table_id: Which table the feature uses for embedding lookups.
+      max_sequence_length: If positive, the feature is a sequence feature with
+        the corresponding maximum sequence length. If the sequence is longer
+        than this, it will be truncated. If 0, the feature is not a sequence
+        feature.
+
+    Returns:
+      `FeatureConfig`.
+
+    Raises:
+      ValueError: if `max_sequence_length` is negative or not an integer.
+    """
+    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
+      raise ValueError('Invalid max_sequence_length {}.'.format(
+          max_sequence_length))
+
+    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length)
+
+
 class EnqueueData(
     collections.namedtuple(
         'EnqueueData',
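A quick usage sketch of the constructor-time validation added above (assuming
the `FeatureConfig` class exactly as introduced in this hunk):

```
plain = tpu_embedding.FeatureConfig('user')  # max_sequence_length defaults to 0
seq = tpu_embedding.FeatureConfig('video', max_sequence_length=8)

try:
  tpu_embedding.FeatureConfig('video', max_sequence_length=-1)
except ValueError as e:
  print(e)  # Invalid max_sequence_length -1.
```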
@@ -281,15 +312,15 @@ class TPUEmbedding(object):
       initializer=initializer, combiner='mean')
   table_to_config_dict = {'video': table_config_video,
                           'user': table_config_user}
-  feature_to_table_dict = {'watched': 'video',
-                           'favorited': 'video',
-                           'friends': 'user'}
+  feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
+                            'favorited': tpu_embedding.FeatureConfig('video'),
+                            'friends': tpu_embedding.FeatureConfig('user')}
   batch_size = 4
   num_hosts = 1
   optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
   mode = tpu_embedding.TRAINING
   embedding = tpu_embedding.TPUEmbedding(
-      table_to_config_dict, feature_to_table_dict,
+      table_to_config_dict, feature_to_config_dict,
       batch_size, num_hosts, mode, optimization_parameters)

   batch_size_per_core = embedding.batch_size_per_core
@@ -336,17 +367,16 @@ class TPUEmbedding(object):
   ```
   """

-  # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table
-  # name, consider `feature_to_config_dict` which maps to `FeatureConfig`.
-  # `FeatureConfig` could have fields other than table name. For example, it
-  # could have a field to indicate that the feature should not be used to
-  # update embedding table (cr/204852758, cr/204940540). Also, this can support
-  # different combiners for different features within the same table.
+  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
+  # the feature should not be used to update embedding table (cr/204852758,
+  # cr/204940540). Also, this can support different combiners for different
+  # features within the same table.
   # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
   # to `FeatureConfig`?

   # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
-  # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively?
+  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
+  # respectively?

   # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
   # for-loops around construction of inputs.
@@ -356,7 +386,7 @@ class TPUEmbedding(object):
   # global setting.
   def __init__(self,
                table_to_config_dict,
-               feature_to_table_dict,
+               feature_to_config_dict,
                batch_size,
                mode,
                master,
@@ -369,9 +399,9 @@ class TPUEmbedding(object):
       table_to_config_dict: A dictionary mapping from string of table name to
         `TableConfig`. Table refers to an embedding table, e.g. `params`
         argument to `tf.nn.embedding_lookup_sparse()`.
-      feature_to_table_dict: A dictionary mapping from string of feature name
-        to string of table name. Feature refers to ids to lookup in embedding
-        table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
+      feature_to_config_dict: A dictionary mapping from string of feature name
+        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
+        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
       batch_size: An `int` representing the global batch size.
       mode: `TRAINING` or `INFERENCE`.
       master: A `string` representing the TensorFlow master to use.
@@ -391,10 +421,12 @@ class TPUEmbedding(object):
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

-    _validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict)
-    self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict)
-    self._table_to_features_dict = _create_table_to_features_dict(
-        self._feature_to_table_dict)
+    _validate_feature_to_config_dict(table_to_config_dict,
+                                     feature_to_config_dict)
+    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
+    self._table_to_features_dict, self._table_to_num_features_dict = (
+        _create_table_to_features_and_num_features_dicts(
+            self._feature_to_config_dict))
     self._combiners = _create_combiners(self._table_to_config_dict,
                                         self._table_to_features_dict)

@@ -504,8 +536,8 @@ class TPUEmbedding(object):
     return copy.copy(self._table_to_config_dict)

   @property
-  def feature_to_table_dict(self):
-    return copy.copy(self._feature_to_table_dict)
+  def feature_to_config_dict(self):
+    return copy.copy(self._feature_to_config_dict)

   @property
   def table_to_features_dict(self):
@@ -526,8 +558,7 @@ class TPUEmbedding(object):
       table_descriptor.vocabulary_size = table_config.vocabulary_size
       table_descriptor.dimension = table_config.dimension

-      features_for_table = self._table_to_features_dict[table]
-      table_descriptor.num_features = len(features_for_table)
+      table_descriptor.num_features = self._table_to_num_features_dict[table]

       table_descriptor.optimization_parameters.learning_rate.constant = (
           self._optimization_parameters.learning_rate)
@@ -652,7 +683,7 @@ class TPUEmbedding(object):
   def _validate_generate_enqueue_ops_enqueue_datas_list(self,
                                                         enqueue_datas_list):
     """Validate `enqueue_datas_list`."""
-    feature_set = set(self._feature_to_table_dict.keys())
+    feature_set = set(self._feature_to_config_dict.keys())
     contiguous_device = None
     for i, enqueue_datas in enumerate(enqueue_datas_list):
       used_feature_set = set(enqueue_datas.keys())
@@ -674,7 +705,7 @@ class TPUEmbedding(object):
       device_feature = None
       for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
-            self._feature_to_table_dict[feature]].combiner
+            self._feature_to_config_dict[feature].table_id].combiner
         if not isinstance(enqueue_data, EnqueueData):
           raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
                            'not mapped to `EnqueueData`. `feature`: {}'.format(
@@ -726,7 +757,7 @@ class TPUEmbedding(object):
     enqueue_data0 = list(enqueue_datas.values())[0]
     with ops.colocate_with(enqueue_data0.embedding_indices):
       (sample_indices_list, embedding_indices_list, aggregation_weights_list,
-       table_ids) = (
+       table_ids, max_sequence_lengths) = (
           self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
       return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
           sample_indices_list,
@@ -734,7 +765,8 @@ class TPUEmbedding(object):
           aggregation_weights_list,
           table_ids,
           device_ordinal=device_ordinal,
-          combiners=self._combiners)
+          combiners=self._combiners,
+          max_sequence_lengths=max_sequence_lengths)

   def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
@@ -748,7 +780,7 @@ class TPUEmbedding(object):
     """

     (sample_indices_list, embedding_indices_list, aggregation_weights_list,
-     table_ids) = [], [], [], []
+     table_ids, max_sequence_lengths) = [], [], [], [], []
    for table_id, table in enumerate(self._table_to_features_dict):
       features = self._table_to_features_dict[table]
       for feature in features:
@@ -769,9 +801,11 @@ class TPUEmbedding(object):
         embedding_indices_list.append(enqueue_data.embedding_indices)

         table_ids.append(table_id)
+        max_sequence_lengths.append(
+            self._feature_to_config_dict[feature].max_sequence_length)

     return (sample_indices_list, embedding_indices_list,
-            aggregation_weights_list, table_ids)
+            aggregation_weights_list, table_ids, max_sequence_lengths)

   def get_activations(self):
     """Get activations for features.
@@ -790,9 +824,21 @@ class TPUEmbedding(object):
     activations = collections.OrderedDict()
     for table_id, table in enumerate(self._table_to_features_dict):
       features = self._table_to_features_dict[table]
-      for lookup_id, feature in enumerate(features):
-        stride = len(self._table_to_features_dict[table])
-        activations[feature] = recv_activations[table_id][lookup_id::stride, :]
+      num_features = self._table_to_num_features_dict[table]
+      feature_index = 0
+      table_activations = array_ops.reshape(
+          recv_activations[table_id],
+          [self.batch_size_per_core, num_features, -1])
+      for feature in features:
+        seq_length = self._feature_to_config_dict[feature].max_sequence_length
+        if not seq_length:
+          activations[feature] = table_activations[:, feature_index, :]
+          feature_index = feature_index + 1
+        else:
+          activations[feature] = (
+              table_activations[:, feature_index:(feature_index+seq_length), :])
+          feature_index = feature_index + seq_length

     return activations

   def generate_send_gradients_op(self, feature_to_gradient_dict):
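The rewritten `get_activations` reshapes each table's activations to
`[batch_size_per_core, num_features, dim]` and hands each feature a slice of
the middle axis: one slot for a non-sequence feature, `max_sequence_length`
consecutive slots for a sequence feature. A NumPy sketch of the slicing with
illustrative shapes (not the library code):

```
import numpy as np

# One table with a non-sequence feature (1 slot) followed by a sequence
# feature with max_sequence_length=3; batch_size_per_core=2, dim=4.
batch, dim, num_features = 2, 4, 1 + 3
recv = np.arange(batch * num_features * dim, dtype=np.float32)
table_activations = recv.reshape([batch, num_features, dim])

plain_act = table_activations[:, 0, :]      # shape [2, 4]
seq_act = table_activations[:, 1:1 + 3, :]  # shape [2, 3, 4]
```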
@@ -815,12 +861,16 @@ class TPUEmbedding(object):
     gradients = []
     for table in self._table_to_features_dict:
       features = self._table_to_features_dict[table]
-      table_gradients = [
-          feature_to_gradient_dict[feature] for feature in features
-      ]
+      table_gradients = []
+      for feature in features:
+        gradient = feature_to_gradient_dict[feature]
+        # Expand dims for non-sequence feature to match sequence features.
+        if gradient.shape.ndims == 2:
+          gradient = array_ops.expand_dims(gradient, 1)
+        table_gradients.append(gradient)
       interleaved_table_grads = array_ops.reshape(
-          array_ops.stack(table_gradients, axis=1),
-          [-1, table_gradients[0].shape[1]])
+          array_ops.concat(table_gradients, axis=1),
+          [-1, table_gradients[0].shape[-1]])
       gradients.append(interleaved_table_grads)
     return tpu_ops.send_tpu_embedding_gradients(
         inputs=gradients, config=self.config_proto.SerializeToString())
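On the gradient path, a non-sequence gradient of shape `[batch, dim]` gets a
singleton middle axis so that every feature contributes `[batch, slots, dim]`;
concatenating along axis 1 and flattening then yields the sample-major slot
layout that matches the activation layout above. A NumPy sketch with
hypothetical shapes:

```
import numpy as np

g_plain = np.zeros([2, 3])  # non-sequence gradient: [batch, dim]
g_seq = np.ones([2, 2, 3])  # sequence gradient: [batch, max_sequence_length, dim]

g_plain = g_plain[:, np.newaxis, :]  # expand_dims(..., 1) -> [2, 1, 3]
stacked = np.concatenate([g_plain, g_seq], axis=1)      # [2, 3, 3]
interleaved = stacked.reshape([-1, stacked.shape[-1]])  # [6, 3]
# Row order: the 3 slots of sample 0, then the 3 slots of sample 1.
```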
@@ -834,21 +884,22 @@ def _validate_table_to_config_dict(table_to_config_dict):
           '`TableConfig`, got {} for {}.'.format(type(v), k))


-def _validate_feature_to_table_dict(table_to_config_dict,
-                                    feature_to_table_dict):
-  """Validate `feature_to_table_dict`."""
-  used_table_set = set(feature_to_table_dict.values())
+def _validate_feature_to_config_dict(table_to_config_dict,
+                                     feature_to_config_dict):
+  """Validate `feature_to_config_dict`."""
+  used_table_set = set([feature.table_id
+                        for feature in feature_to_config_dict.values()])
   table_set = set(table_to_config_dict.keys())

   unused_table_set = table_set - used_table_set
   if unused_table_set:
     raise ValueError('`table_to_config_dict` specifies table that is not '
-                     'used in `feature_to_table_dict`: {}.'
+                     'used in `feature_to_config_dict`: {}.'
                      .format(unused_table_set))

   extra_table_set = used_table_set - table_set
   if extra_table_set:
-    raise ValueError('`feature_to_table_dict` refers to a table that is not '
+    raise ValueError('`feature_to_config_dict` refers to a table that is not '
                      'specified in `table_to_config_dict`: {}.'
                      .format(extra_table_set))

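The validation reduces to two set differences over table names: every
configured table must be referenced by some feature, and every feature must
reference a configured table. A self-contained sketch (names are made up):

```
table_set = {'video', 'user'}  # keys of table_to_config_dict
used_table_set = {'video'}     # table_id of every FeatureConfig

unused_table_set = table_set - used_table_set  # {'user'} -> ValueError
extra_table_set = used_table_set - table_set   # set(); non-empty -> ValueError
```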
@@ -1135,19 +1186,30 @@ def _create_combiners(table_to_config_dict, table_to_features_dict):
   return combiners


-def _create_table_to_features_dict(feature_to_table_dict):
+def _create_table_to_features_and_num_features_dicts(feature_to_config_dict):
   """Create mapping from table to a list of its features."""
   table_to_features_dict_tmp = {}
-  for feature, table in six.iteritems(feature_to_table_dict):
-    if table in table_to_features_dict_tmp:
-      table_to_features_dict_tmp[table].append(feature)
+  table_to_num_features_dict_tmp = {}
+  for feature, feature_config in six.iteritems(feature_to_config_dict):
+    if feature_config.table_id in table_to_features_dict_tmp:
+      table_to_features_dict_tmp[feature_config.table_id].append(feature)
     else:
-      table_to_features_dict_tmp[table] = [feature]
+      table_to_features_dict_tmp[feature_config.table_id] = [feature]
+      table_to_num_features_dict_tmp[feature_config.table_id] = 0
+    if feature_config.max_sequence_length == 0:
+      table_to_num_features_dict_tmp[feature_config.table_id] = (
+          table_to_num_features_dict_tmp[feature_config.table_id] + 1)
+    else:
+      table_to_num_features_dict_tmp[feature_config.table_id] = (
+          table_to_num_features_dict_tmp[feature_config.table_id] +
+          feature_config.max_sequence_length)

   table_to_features_dict = collections.OrderedDict()
+  table_to_num_features_dict = collections.OrderedDict()
   for table in sorted(table_to_features_dict_tmp):
     table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
-  return table_to_features_dict
+    table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table]
+  return table_to_features_dict, table_to_num_features_dict


 def _create_device_fn(hosts):
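The renamed helper now also counts embedding slots per table: one slot for a
non-sequence feature, `max_sequence_length` slots for a sequence feature. This
count feeds `table_descriptor.num_features` and the `get_activations` reshape
above. A pure-Python restatement with a stand-in namedtuple and illustrative
feature names:

```
import collections

FeatureConfig = collections.namedtuple(
    'FeatureConfig', ['table_id', 'max_sequence_length'])

feature_to_config = {
    'watched': FeatureConfig('video', 10),
    'favorited': FeatureConfig('video', 0),
    'friends': FeatureConfig('user', 0),
}

num_features = collections.defaultdict(int)
for config in feature_to_config.values():
  # A sequence feature occupies max_sequence_length slots; others one.
  num_features[config.table_id] += config.max_sequence_length or 1

print(dict(num_features))  # {'video': 11, 'user': 1}
```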
@@ -115,7 +115,7 @@ def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
   """
   new_activations = collections.OrderedDict()
   for feature in activations:
-    table = tpu_embedding.feature_to_table_dict[feature]
+    table = tpu_embedding.feature_to_config_dict[feature].table_id
     new_activations[feature] = tpu_ops.tpu_embedding_activations(
         dummy_table_variables[table],
         activations[feature],