Support aggregation weights for embedding lookup with TPU in TPUEstimator.

PiperOrigin-RevId: 241827445
A. Unique TensorFlower 2019-04-03 16:18:16 -07:00 committed by TensorFlower Gardener
parent bf3bd1c026
commit bb5880c426
3 changed files with 212 additions and 75 deletions


@@ -20,13 +20,13 @@ from __future__ import print_function
import collections
import six
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.feature_column import feature_column as core_fc
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import feature_column as tpu_fc
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu.tpu_embedding import AdagradParameters
@@ -107,18 +107,15 @@ def get_full_variable_names(
return embedding_variable_name_by_table, slot_variable_names_by_table
def get_tpu_embedding_config_from_feature_columns(feature_columns):
"""Create configs for TPUEmbedding from a list of feature columns.
This function will place one embedding tensor per table and the return is
intended to be used as input to TPUEmbedding.
def get_configs_from_feature_columns(feature_columns):
"""Create configs for TPUEmbedding etc from a list of feature columns.
Args:
feature_columns: a list of supported feature columns.
Returns:
A pair of dicts, the first maps tables to their config, the second maps
features to tables.
A tuple of dicts, the first maps tables to their config, the second maps
features to tables, and the third maps features to weight key names.
"""
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
@@ -131,6 +128,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
table_to_config = {}
feature_to_table = {}
feature_to_weight_key_name = {}
for column in feature_columns:
feature_name = column.get_feature_key_name()
table_name = _get_table_name_from_embedding_var_name(
@@ -140,6 +138,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
'Feature column {} is used with multiple embeddings and this is '
'not supported.'.format(feature_name))
feature_to_table[feature_name] = table_name
feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
vocabulary_size, dimension = column.get_embedding_table_size()
table_to_config[table_name] = tpu_embedding.TableConfig(
vocabulary_size=vocabulary_size,
@@ -147,7 +146,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
initializer=column.get_initializer(),
combiner=column.get_combiner())
return table_to_config, feature_to_table
return table_to_config, feature_to_table, feature_to_weight_key_name
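For concreteness, a hypothetical sketch of the three dicts this now returns for a single embedding column over a 'video_id' feature; every name below is invented for illustration:

from tensorflow.python.tpu import tpu_embedding

# Hypothetical output of get_configs_from_feature_columns() for one column.
table_to_config = {
    'video_id_table': tpu_embedding.TableConfig(
        vocabulary_size=100000, dimension=64, initializer=None,
        combiner='mean'),
}
feature_to_table = {'video_id': 'video_id_table'}
# None unless the column wraps a weighted_categorical_column.
feature_to_weight_key_name = {'video_id': None}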
class EmbeddingConfigSpec(
@@ -239,9 +238,10 @@ class EmbeddingConfig(object):
self._num_cores = num_cores
self._run_config = run_config
self._table_to_config_dict, self._feature_to_table_dict = (
get_tpu_embedding_config_from_feature_columns(
embedding_config_spec.feature_columns))
(self._table_to_config_dict, self._feature_to_table_dict,
self.feature_to_weight_key_name_dict) = (
get_configs_from_feature_columns(
embedding_config_spec.feature_columns))
self._mode_to_tpu_embedding_dict = {}
self.dummy_table_variables = None
@@ -305,19 +305,61 @@ class EmbeddingConfig(object):
def split_inputs(ctx, features, labels):
"""Splits the dense and sparse tensors inside the features and labels."""
sparse_features = collections.OrderedDict()
enqueue_datas = collections.OrderedDict()
if ctx.embedding_config:
tpu_embedding_ = ctx.embedding_config.tpu_embedding
feature_to_weight_key_name_dict = (
ctx.embedding_config.feature_to_weight_key_name_dict)
for feature_key in tpu_embedding_.feature_to_table_dict:
sparse_features[feature_key] = features.pop(feature_key)
sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
weight_key_name = feature_to_weight_key_name_dict[feature_key]
if isinstance(sparse_feature, sparse_tensor.SparseTensor):
weights = _get_weights_from_features(weight_key_name, features)
enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
sparse_feature, weights)
else:
if weight_key_name is not None:
raise ValueError(
'Found weights {} for weighted_categorical_column, which is not '
'compatible with sparse feature {} enqueued as dense tensor.'
.format(weight_key_name, feature_key))
enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
enqueue_datas[feature_key] = enqueue_data
for v in six.itervalues(sparse_features):
if not v.dtype.is_integer:
raise ValueError('SparseTensors with string values are not supported. '
'If you are using vocabulary_file_categorical_column or '
'vocabulary_list_categorical_column, please call '
'your_column.categorical_column._transform_feature({'
'your_column.key: features[your_column.key]}) in '
'your input_fn() to convert string to int.')
return features, labels, enqueue_datas
return features, labels, sparse_features
def _get_sparse_feature_from_feature(feature_key, features):
"""Pop and return sparse feature."""
sparse_feature = features.pop(feature_key)
if not sparse_feature.dtype.is_integer:
raise ValueError('SparseTensors with string values are not supported. '
'If you are using vocabulary_file_categorical_column or '
'vocabulary_list_categorical_column, please call '
'your_column.categorical_column._transform_feature({{'
'your_column.key: features[your_column.key]}}) in '
'your input_fn() to convert string to int. '
'feature_key = {}.'.format(feature_key))
return sparse_feature
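To illustrate the conversion this error message asks for, a hedged sketch of an input_fn() that maps string ids to int ids before they reach split_inputs(); the column and feature names are hypothetical, and the private _transform_feature call follows the message above (its exact signature may vary across TF versions):

import tensorflow as tf

# Hypothetical categorical column with a string vocabulary.
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(
    'video_id', vocabulary_list=['a', 'b', 'c'])

def input_fn():
  # String-valued SparseTensor; the TPU path cannot enqueue this directly.
  features = {
      'video_id': tf.SparseTensor(
          indices=[[0, 0], [1, 0]], values=['a', 'c'], dense_shape=[2, 1])
  }
  # Convert strings to int ids, as the error message above suggests.
  features['video_id'] = cat_col._transform_feature(  # pylint: disable=protected-access
      {cat_col.key: features['video_id']})
  labels = tf.constant([0, 1])
  return features, labels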
def _get_weights_from_features(weight_key_name, features):
"""Pop and return feature for weights, possibly None."""
weights = None
if weight_key_name is not None:
if weight_key_name in features:
weights = features.pop(weight_key_name)
else:
raise ValueError(
'Cannot find weights {} for weighted_categorical_column. '
'Please check if the weights are present in the feature dict. Also '
'note that weight-sharing among weighted_categorical_columns is not '
'supported on TPU.'.format(weight_key_name))
if not isinstance(weights, sparse_tensor.SparseTensor):
raise ValueError(
'weighted_categorical_column with weight key name {} has dense '
'weights. Dense weights are not supported on TPU. Please use '
'sparse weights instead.'.format(weight_key_name))
if weights.dtype is not dtypes.float32:
weights = math_ops.to_float(weights)
return weights
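A hedged sketch of the feature dict shape these helpers expect for a weighted_categorical_column; the names are hypothetical. Ids and weights are both SparseTensors with matching indices, since dense weights are rejected on TPU:

import tensorflow as tf

# Hypothetical features for a weighted_categorical_column whose
# weight_feature_key is 'video_weights'.
features = {
    'video_id': tf.SparseTensor(
        indices=[[0, 0], [0, 1]], values=[3, 7], dense_shape=[1, 2]),
    'video_weights': tf.SparseTensor(
        indices=[[0, 0], [0, 1]], values=[0.5, 2.0], dense_shape=[1, 2]),
}
# split_inputs() pops both entries and wraps them into one EnqueueData.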


@@ -28,7 +28,6 @@ from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
@@ -97,6 +96,73 @@ class TableConfig(
initializer, combiner)
class EnqueueData(
collections.namedtuple(
'EnqueueData',
['embedding_indices', 'sample_indices', 'aggregation_weights'])):
"""Data to be enqueued through generate_enqueue_ops()."""
def __new__(cls,
embedding_indices,
sample_indices=None,
aggregation_weights=None):
"""Data to be enqueued through generate_enqueue_ops().
Args:
embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
and int64 are allowed and will be converted to int32 internally.
sample_indices: A rank 2 Tensor specifying the training example to which
the corresponding embedding_indices and aggregation_weights values
belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
If it is None, we assume each embedding_indices belongs to a different
sample. Both int32 and int64 are allowed and will be converted to int32
internally.
aggregation_weights: A rank 1 Tensor containing per training example
aggregation weights. It corresponds to sp_weights.values in
embedding_lookup_sparse(). If it is None, we assume all weights are 1.
Both float32 and float64 are allowed and will be converted to float32
internally.
Returns:
An EnqueueData tuple.
"""
return super(EnqueueData, cls).__new__(cls, embedding_indices,
sample_indices, aggregation_weights)
@staticmethod
def from_sparse_tensor(sp_tensor, weights=None):
return EnqueueData(
sp_tensor.values,
sp_tensor.indices,
aggregation_weights=weights.values if weights is not None else None)
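A short usage sketch (tensor values are illustrative):

import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

ids = tf.SparseTensor(
    indices=[[0, 0], [1, 0]], values=[12, 34], dense_shape=[2, 1])
weights = tf.SparseTensor(
    indices=[[0, 0], [1, 0]], values=[0.5, 2.0], dense_shape=[2, 1])
enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(ids, weights)
# enqueue_data.embedding_indices   is ids.values
# enqueue_data.sample_indices      is ids.indices
# enqueue_data.aggregation_weights is weights.values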
def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
"""Convenient function for generate_enqueue_ops().
Args:
sp_tensors_list: a list of dictionaries mapping feature names to
SparseTensor. Each dictionary is for one TPU core. Dictionaries for the
same host should be contiguous in the list.
Returns:
enqueue_datas_list: a list of dictionaries mapping feature
names to EnqueueData. Each dictionary is for one
TPU core. Dictionaries for the same host should be contiguous
in the list.
"""
enqueue_datas_list = []
for sp_tensors in sp_tensors_list:
enqueue_datas = collections.OrderedDict(
(k, EnqueueData.from_sparse_tensor(v))
for k, v in six.iteritems(sp_tensors))
enqueue_datas_list.append(enqueue_datas)
return enqueue_datas_list
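A usage sketch with one dict of SparseTensors per TPU core (two cores shown; names hypothetical):

import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

def sparse_ids():
  return {'video_id': tf.SparseTensor(
      indices=[[0, 0]], values=[5], dense_shape=[1, 1])}

# One dict per core; dicts for the same host must stay contiguous.
sp_tensors_list = [sparse_ids(), sparse_ids()]
enqueue_datas_list = (
    tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
        sp_tensors_list))
# Each entry maps 'video_id' to an EnqueueData with no weights.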
AdamSlotVariableNames = collections.namedtuple(
'AdamSlotVariableNames', ['m', 'v'])
@@ -564,119 +630,148 @@ class TPUEmbedding(object):
slot_variables_by_table,
load_ops, retrieve_ops)
def generate_enqueue_ops(self, sparse_features_list):
def generate_enqueue_ops(self, enqueue_datas_list):
"""Generate enqueue ops.
Args:
sparse_features_list: a list of dictionaries mapping feature
names to sparse tensors. Each dictionary is for one
enqueue_datas_list: a list of dictionaries mapping feature
names to EnqueueData. Each dictionary is for one
TPU core. Dictionaries for the same host should be contiguous
in the list.
Returns:
Ops to enqueue to TPU for embedding.
"""
self._validate_generate_enqueue_ops_sparse_features_list(
sparse_features_list)
self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
return [
self._generate_enqueue_op(
sparse_features, device_ordinal=i % self._num_cores_per_host)
for i, sparse_features in enumerate(sparse_features_list)
enqueue_datas, device_ordinal=i % self._num_cores_per_host)
for i, enqueue_datas in enumerate(enqueue_datas_list)
]
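A hedged usage sketch; tpu_embedding_ is assumed to be an already-constructed TPUEmbedding, with enqueue_datas_list built as in the helper above:

# Assumes `tpu_embedding_` (a TPUEmbedding instance) and
# `enqueue_datas_list` (one OrderedDict of EnqueueData per core) exist.
enqueue_ops = tpu_embedding_.generate_enqueue_ops(enqueue_datas_list)
# TPUEstimator runs these as part of the per-host infeed, before each
# step dequeues the corresponding embedding activations.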
def _validate_generate_enqueue_ops_sparse_features_list(
self, sparse_features_list):
"""Validate `sparse_features_list`."""
def _validate_generate_enqueue_ops_enqueue_datas_list(self,
enqueue_datas_list):
"""Validate `enqueue_datas_list`."""
feature_set = set(self._feature_to_table_dict.keys())
contiguous_device = None
for i, sparse_features in enumerate(sparse_features_list):
used_feature_set = set(sparse_features.keys())
for i, enqueue_datas in enumerate(enqueue_datas_list):
used_feature_set = set(enqueue_datas.keys())
# Check features are valid.
missing_feature_set = feature_set - used_feature_set
if missing_feature_set:
raise ValueError('`sparse_features_list[{}]` misses a feature that is '
raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
'in `feature_to_config_dict`: {}.'.format(
i, missing_feature_set))
extra_feature_set = used_feature_set - feature_set
if extra_feature_set:
raise ValueError('`sparse_features_list[{}]` has a feature that is not '
raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
'in `feature_to_config_dict`: {}.'.format(
i, extra_feature_set))
device = None
device_feature = None
for feature, tensor in six.iteritems(sparse_features):
for feature, enqueue_data in six.iteritems(enqueue_datas):
combiner = self._table_to_config_dict[
self._feature_to_table_dict[feature]].combiner
if not isinstance(tensor, sparse_tensor.SparseTensor) and combiner:
raise ValueError('`sparse_features_list[{}]` has a feature that is '
'not mapped to `SparseTensor` and has a combiner. '
'`feature`: {}, combiner: {}'.format(
if not isinstance(enqueue_data, EnqueueData):
raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
'not mapped to `EnqueueData`. `feature`: {}'.format(
i, feature))
if enqueue_data.sample_indices is None and combiner:
raise ValueError('`enqueue_datas_list[{}]` has a feature that has '
'no `sample_indices` but whose table has a combiner. '
'`feature`: {}, combiner: {}.'.format(
i, feature, combiner))
if (enqueue_data.sample_indices is not None and
enqueue_data.sample_indices.op.device !=
enqueue_data.embedding_indices.op.device):
raise ValueError(
'Device of sample_indices does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
if (enqueue_data.aggregation_weights is not None and
enqueue_data.aggregation_weights.op.device !=
enqueue_data.embedding_indices.op.device):
raise ValueError(
'Device of aggregation_weights does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
# Check all features are on the same device.
if device is None:
device = tensor.op.device
device = enqueue_data.embedding_indices.op.device
device_feature = feature
else:
if device != tensor.op.device:
if device != enqueue_data.embedding_indices.op.device:
raise ValueError('Devices are different between features in '
'`sparse_features_list[{}]`; '
'`enqueue_datas_list[{}]`; '
'devices: {}, {}; features: {}, {}.'.format(
i, device, tensor.op.device, feature,
device_feature))
i, device,
enqueue_data.embedding_indices.op.device,
feature, device_feature))
if i % self._num_cores_per_host:
if device != contiguous_device:
raise ValueError('We expect the `sparse_features` which are on the '
raise ValueError('We expect the `enqueue_datas` which are on the '
'same host to be contiguous in '
'`sparse_features_list`, '
'`sparse_features_list[{}]` is on device {}, '
'`enqueue_datas_list`, '
'`enqueue_datas_list[{}]` is on device {}, '
'but is expected to be on device {}.'.format(
i, device, contiguous_device))
else:
contiguous_device = device
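To make the contiguity check concrete, a minimal sketch of the expected ordering for two hosts with two cores each; empty dicts stand in for real per-core EnqueueData dicts:

import collections

# Dicts for cores on the same host must be adjacent in the list;
# interleaving hosts would trip the contiguity error above.
host0_core0 = collections.OrderedDict()  # tensors placed on host 0
host0_core1 = collections.OrderedDict()  # tensors placed on host 0
host1_core0 = collections.OrderedDict()  # tensors placed on host 1
host1_core1 = collections.OrderedDict()  # tensors placed on host 1
enqueue_datas_list = [host0_core0, host0_core1, host1_core0, host1_core1]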
def _generate_enqueue_op(self, sparse_features, device_ordinal):
with ops.colocate_with(list(sparse_features.values())[0]):
sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
self._format_for_tpu_embedding_sparse_tensor_batch(sparse_features))
def _generate_enqueue_op(self, enqueue_datas, device_ordinal):
enqueue_data0 = list(enqueue_datas.values())[0]
with ops.colocate_with(enqueue_data0.embedding_indices):
(sample_indices_list, embedding_indices_list, aggregation_weights_list,
table_ids) = (
self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
sample_idcs,
embedding_idcs,
aggregation_weights,
sample_indices_list,
embedding_indices_list,
aggregation_weights_list,
table_ids,
device_ordinal=device_ordinal,
combiners=self._combiners)
def _format_for_tpu_embedding_sparse_tensor_batch(self, sparse_features):
def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
"""Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
Args:
sparse_features: a `Dict` of tensors for embedding. Can be sparse or
dense.
enqueue_datas: a `Dict` mapping feature names to `EnqueueData`.
Returns:
Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
"""
sample_idcs, embedding_idcs, aggregation_weights, table_ids = [], [], [], []
(sample_indices_list, embedding_indices_list, aggregation_weights_list,
table_ids) = [], [], [], []
for table_id, table in enumerate(self._table_to_features_dict):
features = self._table_to_features_dict[table]
for feature in features:
tensor = sparse_features[feature]
if not isinstance(tensor, sparse_tensor.SparseTensor):
sample_idcs.append(array_ops.zeros([0], dtype=dtypes.int32))
embedding_idcs.append(tensor)
else:
sample_idcs.append(tensor.indices)
embedding_idcs.append(tensor.values)
aggregation_weights.append(array_ops.zeros([0]))
enqueue_data = enqueue_datas[feature]
sample_indices = (
enqueue_data.sample_indices
if enqueue_data.sample_indices is not None else array_ops.zeros(
(0,), dtype=dtypes.int32))
sample_indices_list.append(sample_indices)
aggregation_weights = (
enqueue_data.aggregation_weights if
enqueue_data.aggregation_weights is not None else array_ops.zeros(
(0,), dtype=dtypes.float32))
aggregation_weights_list.append(aggregation_weights)
embedding_indices_list.append(enqueue_data.embedding_indices)
table_ids.append(table_id)
return sample_idcs, embedding_idcs, aggregation_weights, table_ids
return (sample_indices_list, embedding_indices_list,
aggregation_weights_list, table_ids)
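For intuition, suppose tableA holds features f1 and f2 and tableB holds f3 (hypothetical names). The four parallel lists returned above then line up as:

# sample_indices_list      = [f1.sample_indices, f2.sample_indices,
#                             f3.sample_indices]       # zeros if dense
# embedding_indices_list   = [f1.embedding_indices, f2.embedding_indices,
#                             f3.embedding_indices]
# aggregation_weights_list = [f1.aggregation_weights, f2.aggregation_weights,
#                             f3.aggregation_weights]  # zeros if unweighted
# table_ids                = [0, 0, 1]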
def get_activations(self):
"""Get activations for features.


@@ -891,7 +891,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
"""Generates the per_host enqueue ops."""
control_deps = []
per_host_sharded_inputs = []
sparse_features_list = []
enqueue_datas_list = []
num_replicas_per_host = ctx.num_of_replicas_per_host
cached_signals = None
with ops.device(device):
@@ -910,9 +910,9 @@
else:
cached_signals = signals
features, labels, sparse_features = (
features, labels, enqueue_datas = (
_tpu_estimator_embedding.split_inputs(ctx, features, labels))
sparse_features_list.append(sparse_features)
enqueue_datas_list.append(enqueue_datas)
inputs_structure_recorder.validate_and_record_structure(
features, labels)
@@ -945,7 +945,7 @@
if ctx.embedding_config:
per_host_enqueue_ops.extend(
ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
sparse_features_list))
enqueue_datas_list))
if signals is None:
return per_host_enqueue_ops
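Putting the pieces together, a hedged sketch of the per-host flow these hunks implement, written as a function so the assumed inputs are explicit; split_inputs_fn stands in for _tpu_estimator_embedding.split_inputs:

def per_host_enqueue_ops(ctx, per_replica_features_and_labels,
                         split_inputs_fn):
  # One (features, labels) pair per replica on this host; split_inputs_fn
  # pops the embedding features into a dict of EnqueueData per core.
  enqueue_datas_list = []
  for features, labels in per_replica_features_and_labels:
    features, labels, enqueue_datas = split_inputs_fn(ctx, features, labels)
    enqueue_datas_list.append(enqueue_datas)
  return ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
      enqueue_datas_list)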