Support aggregation weights for embedding lookup with TPU in TPUEstimator.
PiperOrigin-RevId: 241827445
parent bf3bd1c026
commit bb5880c426
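The sketch below is not part of the diff; the column and feature names are hypothetical, and the `embedding_column` wrapper is assumed to be the one in tensorflow/python/tpu/feature_column.py (`tpu_fc`). It shows roughly what this change enables at the feature-column level: a weighted_categorical_column can now carry per-id aggregation weights through the TPUEstimator embedding pipeline, with the weights read from a separate key in the features dict.

import tensorflow as tf
from tensorflow.python.tpu import feature_column as tpu_fc

# Ids and their per-id weights arrive as two entries in the features dict
# produced by input_fn: 'watched_ids' (int ids) and 'watched_weights'
# (float weights). Both keys are hypothetical.
watched = tf.feature_column.categorical_column_with_identity(
    'watched_ids', num_buckets=100000)
weighted = tf.feature_column.weighted_categorical_column(
    watched, weight_feature_key='watched_weights')
# Assumed TPU embedding wrapper; after this change the weights are forwarded
# to the TPU enqueue op as EnqueueData.aggregation_weights.
watched_embedding = tpu_fc.embedding_column(weighted, dimension=64)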
@@ -20,13 +20,13 @@ from __future__ import print_function
 import collections

 import six

 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.feature_column import feature_column as core_fc
 from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import math_ops
 from tensorflow.python.tpu import feature_column as tpu_fc
 from tensorflow.python.tpu import tpu_embedding
 from tensorflow.python.tpu.tpu_embedding import AdagradParameters
@@ -107,18 +107,15 @@ def get_full_variable_names(
   return embedding_variable_name_by_table, slot_variable_names_by_table


-def get_tpu_embedding_config_from_feature_columns(feature_columns):
-  """Create configs for TPUEmbedding from a list of feature columns.
-
-  This function will place one embedding tensor per table and the return is
-  intended to be used as input to TPUEmbedding.
+def get_configs_from_feature_columns(feature_columns):
+  """Create configs for TPUEmbedding etc from a list of feature columns.

   Args:
     feature_columns: a list of supported feature columns.

   Returns:
-    A pair of dicts, the first maps tables to their config, the second maps
-    features to tables.
+    A tuple of dicts, the first maps tables to their config, the second maps
+    features to tables, and the third maps features to weight key names.
   """
   allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn)  # pylint: disable=protected-access
@@ -131,6 +128,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):

   table_to_config = {}
   feature_to_table = {}
+  feature_to_weight_key_name = {}
   for column in feature_columns:
     feature_name = column.get_feature_key_name()
     table_name = _get_table_name_from_embedding_var_name(
@@ -140,6 +138,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
           'Feature column {} is used with multiple embeddings and this is '
           'not supported.'.format(feature_name))
     feature_to_table[feature_name] = table_name
+    feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
     vocabulary_size, dimension = column.get_embedding_table_size()
     table_to_config[table_name] = tpu_embedding.TableConfig(
         vocabulary_size=vocabulary_size,
@@ -147,7 +146,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns):
         initializer=column.get_initializer(),
         combiner=column.get_combiner())

-  return table_to_config, feature_to_table
+  return table_to_config, feature_to_table, feature_to_weight_key_name


 class EmbeddingConfigSpec(
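To make the new return value concrete, here is an illustrative sketch (not from this commit; the table, feature, and weight-key names are invented) of the three dictionaries get_configs_from_feature_columns() now returns for a single weighted column. A column without weights would map to None in the third dict.

from tensorflow.python.tpu import tpu_embedding

# Hypothetical output for one weighted column named 'watched_ids' whose
# weights arrive under the key 'watched_weights'.
table_to_config = {
    'watched_ids_table': tpu_embedding.TableConfig(
        vocabulary_size=100000, dimension=64, combiner='sum'),
}
feature_to_table = {'watched_ids': 'watched_ids_table'}
feature_to_weight_key_name = {'watched_ids': 'watched_weights'}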
@@ -239,8 +238,9 @@ class EmbeddingConfig(object):
     self._num_cores = num_cores
     self._run_config = run_config

-    self._table_to_config_dict, self._feature_to_table_dict = (
-        get_tpu_embedding_config_from_feature_columns(
+    (self._table_to_config_dict, self._feature_to_table_dict,
+     self.feature_to_weight_key_name_dict) = (
+         get_configs_from_feature_columns(
             embedding_config_spec.feature_columns))
     self._mode_to_tpu_embedding_dict = {}
     self.dummy_table_variables = None
@@ -305,19 +305,61 @@ class EmbeddingConfig(object):

 def split_inputs(ctx, features, labels):
   """Splits the dense and sparse tensors inside the features and labels."""
-  sparse_features = collections.OrderedDict()
+  enqueue_datas = collections.OrderedDict()
   if ctx.embedding_config:
     tpu_embedding_ = ctx.embedding_config.tpu_embedding
+    feature_to_weight_key_name_dict = (
+        ctx.embedding_config.feature_to_weight_key_name_dict)
     for feature_key in tpu_embedding_.feature_to_table_dict:
-      sparse_features[feature_key] = features.pop(feature_key)
-
-  for v in six.itervalues(sparse_features):
-    if not v.dtype.is_integer:
-      raise ValueError('SparseTensor with string as values are not supported. '
-                       'If you are using vocabulary_file_categorical_column or '
-                       'vocabulary_list_categorical_column, please call '
-                       'your_column.categorical_column._transform_feature({'
-                       'your_column.key: features[your_column.key]}) in'
-                       'your input_fn() to convert string to int.')
-
-  return features, labels, sparse_features
+      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
+      weight_key_name = feature_to_weight_key_name_dict[feature_key]
+      if isinstance(sparse_feature, sparse_tensor.SparseTensor):
+        weights = _get_weights_from_features(weight_key_name, features)
+        enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
+            sparse_feature, weights)
+      else:
+        if weight_key_name is not None:
+          raise ValueError(
+              'Found weights {} for weighted_categorical_column, which is not '
+              'compatible with sparse feature {} enqueued as dense tensor.'
+              .format(weight_key_name, feature_key))
+        enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
+      enqueue_datas[feature_key] = enqueue_data
+
+  return features, labels, enqueue_datas
+
+
+def _get_sparse_feature_from_feature(feature_key, features):
+  """Pop and return sparse feature."""
+  sparse_feature = features.pop(feature_key)
+  if not sparse_feature.dtype.is_integer:
+    raise ValueError('SparseTensor with string as values are not supported. '
+                     'If you are using vocabulary_file_categorical_column or '
+                     'vocabulary_list_categorical_column, please call '
+                     'your_column.categorical_column._transform_feature({{'
+                     'your_column.key: features[your_column.key]}}) in '
+                     'your input_fn() to convert string to int. '
+                     'feature_key = {}.'.format(feature_key))
+  return sparse_feature
+
+
+def _get_weights_from_features(weight_key_name, features):
+  """Pop and return feature for weights, possibly None."""
+  weights = None
+  if weight_key_name is not None:
+    if weight_key_name in features:
+      weights = features.pop(weight_key_name)
+    else:
+      raise ValueError(
+          'Cannot find weights {} for weighted_categorical_column.'
+          ' Please check if the weights are present in feature dict. Also'
+          ' note weight-sharing among weighted_categorical_column is not '
+          'supported on TPU.'.format(weight_key_name))
+    if not isinstance(weights, sparse_tensor.SparseTensor):
+      raise ValueError(
+          'weighted_categorical_column with weight key name {} has dense '
+          'weights. Dense weights are not supported on TPU. Please use '
+          'sparse weights instead.'.format(weight_key_name))
+    if weights.dtype is not dtypes.float32:
+      weights = math_ops.to_float(weights)
+  return weights
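To make the new contract concrete, here is a hypothetical input_fn features dict (names and values invented, not from this commit): split_inputs() pops both the id SparseTensor and the weight SparseTensor registered under the column's weight key, casts float64 weights to float32, and leaves the remaining dense features for the model_fn.

import tensorflow as tf

features = {
    # Ids for the embedding feature and their per-id weights.
    'watched_ids': tf.SparseTensor([[0, 0], [1, 0]], [3, 9], [2, 1]),
    'watched_weights': tf.SparseTensor([[0, 0], [1, 0]], [0.5, 2.0], [2, 1]),
    # A dense feature that stays behind.
    'age': tf.constant([[23.0], [31.0]]),
}
# After split_inputs(ctx, features, labels):
#   features      -> {'age': <dense tensor>}
#   enqueue_datas -> {'watched_ids': EnqueueData(embedding_indices=ids.values,
#                                                sample_indices=ids.indices,
#                                                aggregation_weights=weights.values)}
# Dense (non-SparseTensor) weights or a missing weight key raise ValueError.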
@@ -28,7 +28,6 @@ from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
 from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
@@ -97,6 +96,73 @@ class TableConfig(
         initializer, combiner)


+class EnqueueData(
+    collections.namedtuple(
+        'EnqueueData',
+        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
+  """Data to be enqueued through generate_enqueue_ops()."""
+
+  def __new__(cls,
+              embedding_indices,
+              sample_indices=None,
+              aggregation_weights=None):
+    """Data to be enqueued through generate_enqueue_ops().
+
+    Args:
+      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
+        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
+        and int64 are allowed and will be converted to int32 internally.
+      sample_indices: A rank 2 Tensor specifying the training example to which
+        the corresponding embedding_indices and aggregation_weights values
+        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
+        If it is None, we assume each embedding_indices belongs to a different
+        sample. Both int32 and int64 are allowed and will be converted to int32
+        internally.
+      aggregation_weights: A rank 1 Tensor containing per training example
+        aggregation weights. It corresponds to sp_weights.values in
+        embedding_lookup_sparse(). If it is None, we assume all weights are 1.
+        Both float32 and float64 are allowed and will be converted to float32
+        internally.
+
+    Returns:
+      An EnqueueData tuple.
+    """
+    return super(EnqueueData, cls).__new__(cls, embedding_indices,
+                                           sample_indices, aggregation_weights)
+
+  @staticmethod
+  def from_sparse_tensor(sp_tensor, weights=None):
+    return EnqueueData(
+        sp_tensor.values,
+        sp_tensor.indices,
+        aggregation_weights=weights.values if weights is not None else None)
+
+
+def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
+  """Convenience function for generate_enqueue_ops().
+
+  Args:
+    sp_tensors_list: a list of dictionaries mapping from string of feature
+      names to SparseTensor. Each dictionary is for one TPU core. Dictionaries
+      for the same host should be contiguous in the list.
+
+  Returns:
+    enqueue_datas_list: a list of dictionaries mapping from string
+      of feature names to EnqueueData. Each dictionary is for one
+      TPU core. Dictionaries for the same host should be contiguous
+      in the list.
+  """
+  enqueue_datas_list = []
+  for sp_tensors in sp_tensors_list:
+    enqueue_datas = collections.OrderedDict(
+        (k, EnqueueData.from_sparse_tensor(v))
+        for k, v in six.iteritems(sp_tensors))
+    enqueue_datas_list.append(enqueue_datas)
+  return enqueue_datas_list
+
+
 AdamSlotVariableNames = collections.namedtuple(
     'AdamSlotVariableNames', ['m', 'v'])
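For orientation, a small sketch (the tensors are made up; only the API calls come from the diff above): EnqueueData carries the same three pieces that tf.nn.embedding_lookup_sparse() takes as sp_ids / sp_weights, and from_sparse_tensor() is the shorthand that split_inputs() relies on.

import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

sp_ids = tf.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                         values=[12, 4, 4], dense_shape=[2, 2])
sp_weights = tf.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                             values=[1.0, 0.25, 0.75], dense_shape=[2, 2])

data = tpu_embedding.EnqueueData.from_sparse_tensor(sp_ids, sp_weights)
# data.embedding_indices    <- sp_ids.values
# data.sample_indices       <- sp_ids.indices
# data.aggregation_weights  <- sp_weights.values

# Per-core dicts without weights can still be built with the helper added
# above; weights default to None (treated as all ones).
per_core = tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
    [{'watched_ids': sp_ids}])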
@@ -564,119 +630,148 @@ class TPUEmbedding(object):
           slot_variables_by_table,
           load_ops, retrieve_ops)

-  def generate_enqueue_ops(self, sparse_features_list):
+  def generate_enqueue_ops(self, enqueue_datas_list):
     """Generate enqueue ops.

     Args:
-      sparse_features_list: a list of dictionary mapping from string
-        of feature names to sparse tensor. Each dictionary is for one
+      enqueue_datas_list: a list of dictionaries mapping from string
+        of feature names to EnqueueData. Each dictionary is for one
         TPU core. Dictionaries for the same host should be contiguous
         on the list.

     Returns:
       Ops to enqueue to TPU for embedding.
     """
-    self._validate_generate_enqueue_ops_sparse_features_list(
-        sparse_features_list)
+    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
     return [
         self._generate_enqueue_op(
-            sparse_features, device_ordinal=i % self._num_cores_per_host)
-        for i, sparse_features in enumerate(sparse_features_list)
+            enqueue_datas, device_ordinal=i % self._num_cores_per_host)
+        for i, enqueue_datas in enumerate(enqueue_datas_list)
     ]

-  def _validate_generate_enqueue_ops_sparse_features_list(
-      self, sparse_features_list):
-    """Validate `sparse_features_list`."""
+  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
+                                                        enqueue_datas_list):
+    """Validate `enqueue_datas_list`."""
     feature_set = set(self._feature_to_table_dict.keys())
     contiguous_device = None
-    for i, sparse_features in enumerate(sparse_features_list):
-      used_feature_set = set(sparse_features.keys())
+    for i, enqueue_datas in enumerate(enqueue_datas_list):
+      used_feature_set = set(enqueue_datas.keys())

       # Check features are valid.
       missing_feature_set = feature_set - used_feature_set
       if missing_feature_set:
-        raise ValueError('`sparse_features_list[{}]` misses a feature that is '
+        raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
                          'in `feature_to_config_dict`: {}.'.format(
                              i, missing_feature_set))

       extra_feature_set = used_feature_set - feature_set
       if extra_feature_set:
-        raise ValueError('`sparse_features_list[{}]` has a feature that is not '
+        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                          'in `feature_to_config_dict`: {}.'.format(
                              i, extra_feature_set))

       device = None
       device_feature = None
-      for feature, tensor in six.iteritems(sparse_features):
+      for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
             self._feature_to_table_dict[feature]].combiner
-        if not isinstance(tensor, sparse_tensor.SparseTensor) and combiner:
-          raise ValueError('`sparse_features_list[{}]` has a feature that is '
-                           'not mapped to `SparseTensor` and has a combiner. '
-                           '`feature`: {}, combiner: {}'.format(
-                               i, feature, combiner))
+
+        if not isinstance(enqueue_data, EnqueueData):
+          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
+                           'not mapped to `EnqueueData`. `feature`: {}'.format(
+                               i, feature))
+
+        if enqueue_data.sample_indices is None and combiner:
+          raise ValueError('`enqueue_datas_list[{}]` has a feature with a '
+                           'combiner but no `sample_indices`. '
+                           '`feature`: {}, combiner: {}.'.format(
+                               i, feature, combiner))
+
+        if (enqueue_data.sample_indices is not None and
+            enqueue_data.sample_indices.op.device !=
+            enqueue_data.embedding_indices.op.device):
+          raise ValueError(
+              'Device of sample_indices does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
+        if (enqueue_data.aggregation_weights is not None and
+            enqueue_data.aggregation_weights.op.device !=
+            enqueue_data.embedding_indices.op.device):
+          raise ValueError(
+              'Device of aggregation_weights does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
         # Check all features are on the same device.
         if device is None:
-          device = tensor.op.device
+          device = enqueue_data.embedding_indices.op.device
           device_feature = feature
         else:
-          if device != tensor.op.device:
+          if device != enqueue_data.embedding_indices.op.device:
             raise ValueError('Devices are different between features in '
-                             '`sparse_features_list[{}]`; '
+                             '`enqueue_datas_list[{}]`; '
                              'devices: {}, {}; features: {}, {}.'.format(
-                                 i, device, tensor.op.device, feature,
-                                 device_feature))
+                                 i, device,
+                                 enqueue_data.embedding_indices.op.device,
+                                 feature, device_feature))

       if i % self._num_cores_per_host:
         if device != contiguous_device:
-          raise ValueError('We expect the `sparse_features` which are on the '
+          raise ValueError('We expect the `enqueue_datas` which are on the '
                            'same host to be contiguous in '
-                           '`sparse_features_list`, '
-                           '`sparse_features_list[{}]` is on device {}, '
+                           '`enqueue_datas_list`, '
+                           '`enqueue_datas_list[{}]` is on device {}, '
                            'but is expected to be on device {}.'.format(
                                i, device, contiguous_device))
       else:
         contiguous_device = device

-  def _generate_enqueue_op(self, sparse_features, device_ordinal):
-    with ops.colocate_with(list(sparse_features.values())[0]):
-      sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
-          self._format_for_tpu_embedding_sparse_tensor_batch(sparse_features))
+  def _generate_enqueue_op(self, enqueue_datas, device_ordinal):
+    enqueue_data0 = list(enqueue_datas.values())[0]
+    with ops.colocate_with(enqueue_data0.embedding_indices):
+      (sample_indices_list, embedding_indices_list, aggregation_weights_list,
+       table_ids) = (
+           self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
       return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
-          sample_idcs,
-          embedding_idcs,
-          aggregation_weights,
+          sample_indices_list,
+          embedding_indices_list,
+          aggregation_weights_list,
           table_ids,
           device_ordinal=device_ordinal,
           combiners=self._combiners)

-  def _format_for_tpu_embedding_sparse_tensor_batch(self, sparse_features):
+  def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.

     Args:
-      sparse_features: a `Dict` of tensors for embedding. Can be sparse or
+      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
         dense.

     Returns:
       Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
     """

-    sample_idcs, embedding_idcs, aggregation_weights, table_ids = [], [], [], []
+    (sample_indices_list, embedding_indices_list, aggregation_weights_list,
+     table_ids) = [], [], [], []
     for table_id, table in enumerate(self._table_to_features_dict):
       features = self._table_to_features_dict[table]
       for feature in features:
-        tensor = sparse_features[feature]
-        if not isinstance(tensor, sparse_tensor.SparseTensor):
-          sample_idcs.append(array_ops.zeros([0], dtype=dtypes.int32))
-          embedding_idcs.append(tensor)
-        else:
-          sample_idcs.append(tensor.indices)
-          embedding_idcs.append(tensor.values)
-        aggregation_weights.append(array_ops.zeros([0]))
+        enqueue_data = enqueue_datas[feature]
+
+        sample_indices = (
+            enqueue_data.sample_indices
+            if enqueue_data.sample_indices is not None else array_ops.zeros(
+                (0,), dtype=dtypes.int32))
+        sample_indices_list.append(sample_indices)
+
+        aggregation_weights = (
+            enqueue_data.aggregation_weights if
+            enqueue_data.aggregation_weights is not None else array_ops.zeros(
+                (0,), dtype=dtypes.float32))
+        aggregation_weights_list.append(aggregation_weights)
+
+        embedding_indices_list.append(enqueue_data.embedding_indices)
+
         table_ids.append(table_id)

-    return sample_idcs, embedding_idcs, aggregation_weights, table_ids
+    return (sample_indices_list, embedding_indices_list,
+            aggregation_weights_list, table_ids)

   def get_activations(self):
     """Get activations for features.
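A loose usage sketch (assumed setup, not from this commit) of how the reworked generate_enqueue_ops() is fed: one OrderedDict of EnqueueData per core, with dicts for cores on the same host contiguous in the list. Here `tpu_embedding_` stands for an already-constructed TPUEmbedding whose feature-to-table mapping covers the hypothetical 'watched_ids' feature.

import collections
import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

def _core_enqueue_datas(core_id):
  # Hypothetical per-core output of the input pipeline.
  ids = tf.SparseTensor(indices=[[0, 0], [0, 1]],
                        values=[5 + core_id, 17], dense_shape=[8, 2])
  weights = tf.SparseTensor(indices=[[0, 0], [0, 1]],
                            values=[1.0, 0.5], dense_shape=[8, 2])
  return collections.OrderedDict(
      watched_ids=tpu_embedding.EnqueueData.from_sparse_tensor(ids, weights))

enqueue_datas_list = [_core_enqueue_datas(core) for core in range(2)]
# enqueue_ops = tpu_embedding_.generate_enqueue_ops(enqueue_datas_list)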
@@ -891,7 +891,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
     """Generates the per_host enqueue ops."""
     control_deps = []
     per_host_sharded_inputs = []
-    sparse_features_list = []
+    enqueue_datas_list = []
     num_replicas_per_host = ctx.num_of_replicas_per_host
     cached_signals = None
     with ops.device(device):
@@ -910,9 +910,9 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
         else:
           cached_signals = signals

-        features, labels, sparse_features = (
+        features, labels, enqueue_data = (
             _tpu_estimator_embedding.split_inputs(ctx, features, labels))
-        sparse_features_list.append(sparse_features)
+        enqueue_datas_list.append(enqueue_data)

         inputs_structure_recorder.validate_and_record_structure(
             features, labels)
@@ -945,7 +945,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
     if ctx.embedding_config:
       per_host_enqueue_ops.extend(
           ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
-              sparse_features_list))
+              enqueue_datas_list))

     if signals is None:
       return per_host_enqueue_ops