From bb5880c426c8b226dc4b8f3a3e77b318b9d4837e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 3 Apr 2019 16:18:16 -0700
Subject: [PATCH] Support aggregation weights for embedding lookup with TPU in
 TPUEstimator.

PiperOrigin-RevId: 241827445
---
 .../python/tpu/_tpu_estimator_embedding.py |  90 ++++++---
 tensorflow/python/tpu/tpu_embedding.py     | 189 +++++++++++++-----
 tensorflow/python/tpu/tpu_estimator.py     |   8 +-
 3 files changed, 212 insertions(+), 75 deletions(-)

diff --git a/tensorflow/python/tpu/_tpu_estimator_embedding.py b/tensorflow/python/tpu/_tpu_estimator_embedding.py
index a7002c958e8..67eb2c3ed52 100644
--- a/tensorflow/python/tpu/_tpu_estimator_embedding.py
+++ b/tensorflow/python/tpu/_tpu_estimator_embedding.py
@@ -20,13 +20,13 @@ from __future__ import print_function
 
 import collections
 
-import six
-
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.feature_column import feature_column as core_fc
 from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
 from tensorflow.python.tpu import feature_column as tpu_fc
 from tensorflow.python.tpu import tpu_embedding
 from tensorflow.python.tpu.tpu_embedding import AdagradParameters
@@ -107,18 +107,15 @@ def get_full_variable_names(
   return embedding_variable_name_by_table, slot_variable_names_by_table
 
 
-def get_tpu_embedding_config_from_feature_columns(feature_columns):
-  """Create configs for TPUEmbedding from a list of feature columns.
-
-  This function will place one embedding tensor per table and the return is
-  intended to be used as input to TPUEmbedding.
+def get_configs_from_feature_columns(feature_columns):
+  """Creates configs for TPUEmbedding from a list of feature columns.
 
   Args:
     feature_columns: a list of supported feature columns.
 
   Returns:
-    A pair of dicts, the first maps tables to their config, the second maps
-    features to tables.
+    A tuple of three dicts: the first maps tables to their config, the second
+    maps features to tables, and the third maps features to weight key names.
""" allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access @@ -131,6 +128,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns): table_to_config = {} feature_to_table = {} + feature_to_weight_key_name = {} for column in feature_columns: feature_name = column.get_feature_key_name() table_name = _get_table_name_from_embedding_var_name( @@ -140,6 +138,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns): 'Feature column {} is used with multiple embeddings and this is ' 'not supported.'.format(feature_name)) feature_to_table[feature_name] = table_name + feature_to_weight_key_name[feature_name] = column.get_weight_key_name() vocabulary_size, dimension = column.get_embedding_table_size() table_to_config[table_name] = tpu_embedding.TableConfig( vocabulary_size=vocabulary_size, @@ -147,7 +146,7 @@ def get_tpu_embedding_config_from_feature_columns(feature_columns): initializer=column.get_initializer(), combiner=column.get_combiner()) - return table_to_config, feature_to_table + return table_to_config, feature_to_table, feature_to_weight_key_name class EmbeddingConfigSpec( @@ -239,9 +238,10 @@ class EmbeddingConfig(object): self._num_cores = num_cores self._run_config = run_config - self._table_to_config_dict, self._feature_to_table_dict = ( - get_tpu_embedding_config_from_feature_columns( - embedding_config_spec.feature_columns)) + (self._table_to_config_dict, self._feature_to_table_dict, + self.feature_to_weight_key_name_dict) = ( + get_configs_from_feature_columns( + embedding_config_spec.feature_columns)) self._mode_to_tpu_embedding_dict = {} self.dummy_table_variables = None @@ -305,19 +305,61 @@ class EmbeddingConfig(object): def split_inputs(ctx, features, labels): """Splits the dense and sparse tensors inside the features and labels.""" - sparse_features = collections.OrderedDict() + enqueue_datas = collections.OrderedDict() if ctx.embedding_config: tpu_embedding_ = ctx.embedding_config.tpu_embedding + feature_to_weight_key_name_dict = ( + ctx.embedding_config.feature_to_weight_key_name_dict) for feature_key in tpu_embedding_.feature_to_table_dict: - sparse_features[feature_key] = features.pop(feature_key) + sparse_feature = _get_sparse_feature_from_feature(feature_key, features) + weight_key_name = feature_to_weight_key_name_dict[feature_key] + if isinstance(sparse_feature, sparse_tensor.SparseTensor): + weights = _get_weights_from_features(weight_key_name, features) + enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor( + sparse_feature, weights) + else: + if weight_key_name is not None: + raise ValueError( + 'Found weights {} for weighted_categorical_column, which is not' + 'compatible with sparse feature {} enqueued as dense tensor.' + .format(weight_key_name, feature_key)) + enqueue_data = tpu_embedding.EnqueueData(sparse_feature) + enqueue_datas[feature_key] = enqueue_data - for v in six.itervalues(sparse_features): - if not v.dtype.is_integer: - raise ValueError('SparseTensor with string as values are not supported. 
-                       'If you are using vocabulary_file_categorical_column or '
-                       'vocabulary_list_categorical_column, please call '
-                       'your_column.categorical_column._transform_feature({'
-                       'your_column.key: features[your_column.key]}) in'
-                       'your input_fn() to convert string to int.')
+  return features, labels, enqueue_datas
 
-  return features, labels, sparse_features
+
+def _get_sparse_feature_from_feature(feature_key, features):
+  """Pops and returns the sparse feature for `feature_key`."""
+  sparse_feature = features.pop(feature_key)
+  if not sparse_feature.dtype.is_integer:
+    raise ValueError('SparseTensors with string values are not supported. '
+                     'If you are using vocabulary_file_categorical_column or '
+                     'vocabulary_list_categorical_column, please call '
+                     'your_column.categorical_column._transform_feature({{'
+                     'your_column.key: features[your_column.key]}}) in '
+                     'your input_fn() to convert string to int. '
+                     'feature_key = {}.'.format(feature_key))
+  return sparse_feature
+
+
+def _get_weights_from_features(weight_key_name, features):
+  """Pops and returns the weights for `weight_key_name`, possibly None."""
+  weights = None
+  if weight_key_name is not None:
+    if weight_key_name in features:
+      weights = features.pop(weight_key_name)
+    else:
+      raise ValueError(
+          'Cannot find weights {} for weighted_categorical_column. '
+          'Please check that the weights are present in the feature dict. '
+          'Also note that weight-sharing among weighted_categorical_columns '
+          'is not supported on TPU.'.format(weight_key_name))
+    if not isinstance(weights, sparse_tensor.SparseTensor):
+      raise ValueError(
+          'weighted_categorical_column with weight key name {} has dense '
+          'weights. Dense weights are not supported on TPU. Please use '
+          'sparse weights instead.'.format(weight_key_name))
+    if weights.dtype is not dtypes.float32:
+      weights = math_ops.to_float(weights)
+  return weights
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index a9d2bc64e14..ba11171ffba 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -28,7 +28,6 @@ from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
 from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
@@ -97,6 +96,73 @@ class TableConfig(
         initializer, combiner)
 
 
+class EnqueueData(
+    collections.namedtuple(
+        'EnqueueData',
+        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
+  """Data to be enqueued through generate_enqueue_ops()."""
+
+  def __new__(cls,
+              embedding_indices,
+              sample_indices=None,
+              aggregation_weights=None):
+    """Data to be enqueued through generate_enqueue_ops().
+
+    Args:
+      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
+        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
+        and int64 are allowed and will be converted to int32 internally.
+      sample_indices: A rank 2 Tensor specifying the training example to which
+        the corresponding embedding_indices and aggregation_weights values
+        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
+        If it is None, we assume each embedding_indices belongs to a different
+        sample. Both int32 and int64 are allowed and will be converted to int32
+        internally.
+      aggregation_weights: A rank 1 Tensor containing per-training-example
+        aggregation weights. It corresponds to sp_weights.values in
+        embedding_lookup_sparse(). If it is None, we assume all weights are 1.
+        Both float32 and float64 are allowed and will be converted to float32
+        internally.
+
+    Returns:
+      An EnqueueData tuple.
+
+    """
+    return super(EnqueueData, cls).__new__(cls, embedding_indices,
+                                           sample_indices, aggregation_weights)
+
+  @staticmethod
+  def from_sparse_tensor(sp_tensor, weights=None):
+    return EnqueueData(
+        sp_tensor.values,
+        sp_tensor.indices,
+        aggregation_weights=weights.values if weights is not None else None)
+
+
+def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
+  """Convenience function for generate_enqueue_ops().
+
+  Args:
+    sp_tensors_list: a list of dictionaries mapping feature name strings to
+      SparseTensor. Each dictionary is for one TPU core. Dictionaries for the
+      same host should be contiguous in the list.
+
+  Returns:
+    enqueue_datas_list: a list of dictionaries mapping feature name
+      strings to EnqueueData. Each dictionary is for one TPU core.
+      Dictionaries for the same host should be contiguous in the
+      list.
+
+  """
+  enqueue_datas_list = []
+  for sp_tensors in sp_tensors_list:
+    enqueue_datas = collections.OrderedDict(
+        (k, EnqueueData.from_sparse_tensor(v))
+        for k, v in six.iteritems(sp_tensors))
+    enqueue_datas_list.append(enqueue_datas)
+  return enqueue_datas_list
+
+
 AdamSlotVariableNames = collections.namedtuple(
     'AdamSlotVariableNames', ['m', 'v'])
@@ -564,119 +630,148 @@ class TPUEmbedding(object):
             slot_variables_by_table, load_ops, retrieve_ops)
 
-  def generate_enqueue_ops(self, sparse_features_list):
+  def generate_enqueue_ops(self, enqueue_datas_list):
     """Generate enqueue ops.
 
     Args:
-      sparse_features_list: a list of dictionary mapping from string
-      of feature names to sparse tensor. Each dictionary is for one
-      TPU core. Dictionaries for the same host should be contiguous
-      on the list.
+      enqueue_datas_list: a list of dictionaries mapping feature name
+        strings to EnqueueData. Each dictionary is for one TPU core.
+        Dictionaries for the same host should be contiguous in the
+        list.
 
     Returns:
       Ops to enqueue to TPU for embedding.
     """
-    self._validate_generate_enqueue_ops_sparse_features_list(
-        sparse_features_list)
+    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
     return [
         self._generate_enqueue_op(
-            sparse_features, device_ordinal=i % self._num_cores_per_host)
-        for i, sparse_features in enumerate(sparse_features_list)
+            enqueue_datas, device_ordinal=i % self._num_cores_per_host)
+        for i, enqueue_datas in enumerate(enqueue_datas_list)
     ]
 
-  def _validate_generate_enqueue_ops_sparse_features_list(
-      self, sparse_features_list):
-    """Validate `sparse_features_list`."""
+  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
+                                                        enqueue_datas_list):
+    """Validate `enqueue_datas_list`."""
     feature_set = set(self._feature_to_table_dict.keys())
     contiguous_device = None
-    for i, sparse_features in enumerate(sparse_features_list):
-      used_feature_set = set(sparse_features.keys())
+    for i, enqueue_datas in enumerate(enqueue_datas_list):
+      used_feature_set = set(enqueue_datas.keys())
 
       # Check features are valid.
      missing_feature_set = feature_set - used_feature_set
       if missing_feature_set:
-        raise ValueError('`sparse_features_list[{}]` misses a feature that is '
+        raise ValueError('`enqueue_datas_list[{}]` is missing a feature that is '
                          'in `feature_to_config_dict`: {}.'.format(
                              i, missing_feature_set))
 
       extra_feature_set = used_feature_set - feature_set
       if extra_feature_set:
-        raise ValueError('`sparse_features_list[{}]` has a feature that is not '
+        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                          'in `feature_to_config_dict`: {}.'.format(
                              i, extra_feature_set))
 
       device = None
       device_feature = None
-      for feature, tensor in six.iteritems(sparse_features):
+      for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
             self._feature_to_table_dict[feature]].combiner
-        if not isinstance(tensor, sparse_tensor.SparseTensor) and combiner:
-          raise ValueError('`sparse_features_list[{}]` has a feature that is '
-                           'not mapped to `SparseTensor` and has a combiner. '
-                           '`feature`: {}, combiner: {}'.format(
+        if not isinstance(enqueue_data, EnqueueData):
+          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
+                           'not mapped to `EnqueueData`. `feature`: {}'.format(
+                               i, feature))
+
+        if enqueue_data.sample_indices is None and combiner:
+          raise ValueError('`enqueue_datas_list[{}]` has a feature with no '
+                           '`sample_indices`, but its table has a combiner. '
+                           '`feature`: {}, combiner: {}.'.format(
                                i, feature, combiner))
+        if (enqueue_data.sample_indices is not None and
+            enqueue_data.sample_indices.op.device !=
+            enqueue_data.embedding_indices.op.device):
+          raise ValueError(
+              'Device of sample_indices does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
+        if (enqueue_data.aggregation_weights is not None and
+            enqueue_data.aggregation_weights.op.device !=
+            enqueue_data.embedding_indices.op.device):
+          raise ValueError(
+              'Device of aggregation_weights does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
 
         # Check all features are on the same device.
        if device is None:
-          device = tensor.op.device
+          device = enqueue_data.embedding_indices.op.device
           device_feature = feature
         else:
-          if device != tensor.op.device:
+          if device != enqueue_data.embedding_indices.op.device:
             raise ValueError('Devices are different between features in '
-                             '`sparse_features_list[{}]`; '
+                             '`enqueue_datas_list[{}]`; '
                              'devices: {}, {}; features: {}, {}.'.format(
-                                 i, device, tensor.op.device, feature,
-                                 device_feature))
+                                 i, device,
+                                 enqueue_data.embedding_indices.op.device,
+                                 feature, device_feature))
 
       if i % self._num_cores_per_host:
         if device != contiguous_device:
-          raise ValueError('We expect the `sparse_features` which are on the '
+          raise ValueError('We expect the `enqueue_datas` that are on the '
                            'same host to be contiguous in '
-                           '`sparse_features_list`, '
-                           '`sparse_features_list[{}]` is on device {}, '
+                           '`enqueue_datas_list`; '
+                           '`enqueue_datas_list[{}]` is on device {}, '
                            'but is expected to be on device {}.'.format(
                                i, device, contiguous_device))
       else:
         contiguous_device = device
 
-  def _generate_enqueue_op(self, sparse_features, device_ordinal):
-    with ops.colocate_with(list(sparse_features.values())[0]):
-      sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
-          self._format_for_tpu_embedding_sparse_tensor_batch(sparse_features))
+  def _generate_enqueue_op(self, enqueue_datas, device_ordinal):
+    enqueue_data0 = list(enqueue_datas.values())[0]
+    with ops.colocate_with(enqueue_data0.embedding_indices):
+      (sample_indices_list, embedding_indices_list, aggregation_weights_list,
+       table_ids) = (
+           self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
       return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
-          sample_idcs,
-          embedding_idcs,
-          aggregation_weights,
+          sample_indices_list,
+          embedding_indices_list,
+          aggregation_weights_list,
           table_ids,
           device_ordinal=device_ordinal,
          combiners=self._combiners)
 
-  def _format_for_tpu_embedding_sparse_tensor_batch(self, sparse_features):
+  def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
 
     Args:
-      sparse_features: a `Dict` of tensors for embedding. Can be sparse or
+      enqueue_datas: a `Dict` of `EnqueueData` for embedding. Can be sparse or
         dense.
 
     Returns:
       Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
""" - sample_idcs, embedding_idcs, aggregation_weights, table_ids = [], [], [], [] + (sample_indices_list, embedding_indices_list, aggregation_weights_list, + table_ids) = [], [], [], [] for table_id, table in enumerate(self._table_to_features_dict): features = self._table_to_features_dict[table] for feature in features: - tensor = sparse_features[feature] - if not isinstance(tensor, sparse_tensor.SparseTensor): - sample_idcs.append(array_ops.zeros([0], dtype=dtypes.int32)) - embedding_idcs.append(tensor) - else: - sample_idcs.append(tensor.indices) - embedding_idcs.append(tensor.values) - aggregation_weights.append(array_ops.zeros([0])) + enqueue_data = enqueue_datas[feature] + + sample_indices = ( + enqueue_data.sample_indices + if enqueue_data.sample_indices is not None else array_ops.zeros( + (0,), dtype=dtypes.int32)) + sample_indices_list.append(sample_indices) + + aggregation_weights = ( + enqueue_data.aggregation_weights if + enqueue_data.aggregation_weights is not None else array_ops.zeros( + (0,), dtype=dtypes.float32)) + aggregation_weights_list.append(aggregation_weights) + + embedding_indices_list.append(enqueue_data.embedding_indices) + table_ids.append(table_id) - return sample_idcs, embedding_idcs, aggregation_weights, table_ids + return (sample_indices_list, embedding_indices_list, + aggregation_weights_list, table_ids) def get_activations(self): """Get activations for features. diff --git a/tensorflow/python/tpu/tpu_estimator.py b/tensorflow/python/tpu/tpu_estimator.py index 43a8fbb7a19..bf084671445 100644 --- a/tensorflow/python/tpu/tpu_estimator.py +++ b/tensorflow/python/tpu/tpu_estimator.py @@ -891,7 +891,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( """Generates the per_host enqueue ops.""" control_deps = [] per_host_sharded_inputs = [] - sparse_features_list = [] + enqueue_datas_list = [] num_replicas_per_host = ctx.num_of_replicas_per_host cached_signals = None with ops.device(device): @@ -910,9 +910,9 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( else: cached_signals = signals - features, labels, sparse_features = ( + features, labels, enqueue_data = ( _tpu_estimator_embedding.split_inputs(ctx, features, labels)) - sparse_features_list.append(sparse_features) + enqueue_datas_list.append(enqueue_data) inputs_structure_recorder.validate_and_record_structure( features, labels) @@ -945,7 +945,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( if ctx.embedding_config: per_host_enqueue_ops.extend( ctx.embedding_config.tpu_embedding.generate_enqueue_ops( - sparse_features_list)) + enqueue_datas_list)) if signals is None: return per_host_enqueue_ops