diff --git a/tensorflow/python/tpu/_tpu_estimator_embedding.py b/tensorflow/python/tpu/_tpu_estimator_embedding.py
index 67eb2c3ed52..4a832dbbe3a 100644
--- a/tensorflow/python/tpu/_tpu_estimator_embedding.py
+++ b/tensorflow/python/tpu/_tpu_estimator_embedding.py
@@ -115,7 +115,7 @@ def get_configs_from_feature_columns(feature_columns):
 
   Returns:
     A tuple of dicts, the first maps tables to their config, the second maps
-    features to tables, and the third maps features to weight key names.
+    features to their config, and the third maps features to weight key names.
   """
 
   allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn)  # pylint: disable=protected-access
@@ -127,17 +127,18 @@ def get_configs_from_feature_columns(feature_columns):
           type(column), allowed))
 
   table_to_config = {}
-  feature_to_table = {}
+  feature_to_config = {}
   feature_to_weight_key_name = {}
   for column in feature_columns:
     feature_name = column.get_feature_key_name()
     table_name = _get_table_name_from_embedding_var_name(
         column.get_embedding_var_name())
-    if feature_name in feature_to_table:
+    if feature_name in feature_to_config:
       raise ValueError(
           'Feature column {} is used with multiple embeddings and this is '
           'not supported.'.format(feature_name))
-    feature_to_table[feature_name] = table_name
+    feature_to_config[feature_name] = tpu_embedding.FeatureConfig(
+        table_id=table_name)
     feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
     vocabulary_size, dimension = column.get_embedding_table_size()
     table_to_config[table_name] = tpu_embedding.TableConfig(
@@ -146,7 +147,7 @@ def get_configs_from_feature_columns(feature_columns):
         initializer=column.get_initializer(),
         combiner=column.get_combiner())
 
-  return table_to_config, feature_to_table, feature_to_weight_key_name
+  return table_to_config, feature_to_config, feature_to_weight_key_name
 
 
 class EmbeddingConfigSpec(
@@ -238,7 +239,7 @@ class EmbeddingConfig(object):
     self._num_cores = num_cores
     self._run_config = run_config
 
-    (self._table_to_config_dict, self._feature_to_table_dict,
+    (self._table_to_config_dict, self._feature_to_config_dict,
      self.feature_to_weight_key_name_dict) = (
          get_configs_from_feature_columns(
              embedding_config_spec.feature_columns))
@@ -286,7 +287,7 @@ class EmbeddingConfig(object):
       cluster_def = None
     tpu_embedding_ = tpu_embedding.TPUEmbedding(
         self._table_to_config_dict,
-        self._feature_to_table_dict,
+        self._feature_to_config_dict,
         batch_size,
         tpu_embedding_mode,
         master,
@@ -310,7 +311,7 @@ def split_inputs(ctx, features, labels):
     tpu_embedding_ = ctx.embedding_config.tpu_embedding
     feature_to_weight_key_name_dict = (
         ctx.embedding_config.feature_to_weight_key_name_dict)
-    for feature_key in tpu_embedding_.feature_to_table_dict:
+    for feature_key in tpu_embedding_.feature_to_config_dict:
       sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
       weight_key_name = feature_to_weight_key_name_dict[feature_key]
       if isinstance(sparse_feature, sparse_tensor.SparseTensor):
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index ba11171ffba..95287123b90 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -96,6 +96,37 @@ class TableConfig(
                                                   initializer, combiner)
 
 
+class FeatureConfig(
+    collections.namedtuple(
+        'FeatureConfig',
+        ['table_id', 'max_sequence_length'])):
+  """Feature configuration."""
+
+  def __new__(cls,
+              table_id,
+              max_sequence_length=0):
+    """Feature configuration.
+
+    Args:
+      table_id: Which table the feature uses for embedding lookups.
+      max_sequence_length: If positive, the feature is a sequence feature with
+        the corresponding maximum sequence length. If the sequence is longer
+        than this, it will be truncated. If 0, the feature is not a sequence
+        feature.
+
+    Returns:
+      `FeatureConfig`.
+
+    Raises:
+      ValueError: if `max_sequence_length` is not a non-negative integer.
+    """
+    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
+      raise ValueError('Invalid max_sequence_length {}.'.format(
+          max_sequence_length))
+
+    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length)
+
+
 class EnqueueData(
     collections.namedtuple(
         'EnqueueData',
@@ -281,15 +312,15 @@ class TPUEmbedding(object):
      initializer=initializer, combiner='mean')
  table_to_config_dict = {'video': table_config_video,
                          'user': table_config_user}
-  feature_to_table_dict = {'watched': 'video',
-                           'favorited': 'video',
-                           'friends': 'user'}
+  feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
+                            'favorited': tpu_embedding.FeatureConfig('video'),
+                            'friends': tpu_embedding.FeatureConfig('user')}
  batch_size = 4
  num_hosts = 1
  optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
  mode = tpu_embedding.TRAINING
  embedding = tpu_embedding.TPUEmbedding(
-      table_to_config_dict, feature_to_table_dict,
+      table_to_config_dict, feature_to_config_dict,
      batch_size, num_hosts, mode, optimization_parameters)
 
  batch_size_per_core = embedding.batch_size_per_core
@@ -336,17 +367,16 @@ class TPUEmbedding(object):
  ```
  """
 
-  # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table
-  # name, consider `feature_to_config_dict` which maps to `FeatureConfig`.
-  # `FeatureConfig` could have fields other than table name. For example, it
-  # could have a field to indicate that the feature should not be used to
-  # update embedding table (cr/204852758, cr/204940540). Also, this can support
-  # different combiners for different features within the same table.
+  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
+  # the feature should not be used to update embedding table (cr/204852758,
+  # cr/204940540). Also, this can support different combiners for different
+  # features within the same table.
 
   # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
   # to `FeatureConfig`?
   # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
-  # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively?
+  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
+  # respectively?
 
   # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
   # for-loops around construction of inputs.
@@ -356,7 +386,7 @@ class TPUEmbedding(object):
   # global setting.
   def __init__(self,
                table_to_config_dict,
-               feature_to_table_dict,
+               feature_to_config_dict,
                batch_size,
                mode,
                master,
@@ -369,9 +399,9 @@ class TPUEmbedding(object):
       table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
-      feature_to_table_dict: A dictionary mapping from string of feature name
-        to string of table name. Feature refers to ids to lookup in embedding
-        table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
+      feature_to_config_dict: A dictionary mapping from string of feature name
+        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
+        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
       batch_size: An `int` representing the global batch size.
       mode: `TRAINING` or `INFERENCE`.
       master: A `string` representing the TensorFlow master to use.
@@ -391,10 +421,12 @@ class TPUEmbedding(object):
 
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
-    _validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict)
-    self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict)
-    self._table_to_features_dict = _create_table_to_features_dict(
-        self._feature_to_table_dict)
+    _validate_feature_to_config_dict(table_to_config_dict,
+                                     feature_to_config_dict)
+    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
+    self._table_to_features_dict, self._table_to_num_features_dict = (
+        _create_table_to_features_and_num_features_dicts(
+            self._feature_to_config_dict))
     self._combiners = _create_combiners(self._table_to_config_dict,
                                         self._table_to_features_dict)
 
@@ -504,8 +536,8 @@ class TPUEmbedding(object):
     return copy.copy(self._table_to_config_dict)
 
   @property
-  def feature_to_table_dict(self):
-    return copy.copy(self._feature_to_table_dict)
+  def feature_to_config_dict(self):
+    return copy.copy(self._feature_to_config_dict)
 
   @property
   def table_to_features_dict(self):
@@ -526,8 +558,7 @@ class TPUEmbedding(object):
       table_descriptor.vocabulary_size = table_config.vocabulary_size
       table_descriptor.dimension = table_config.dimension
 
-      features_for_table = self._table_to_features_dict[table]
-      table_descriptor.num_features = len(features_for_table)
+      table_descriptor.num_features = self._table_to_num_features_dict[table]
 
       table_descriptor.optimization_parameters.learning_rate.constant = (
           self._optimization_parameters.learning_rate)
@@ -652,7 +683,7 @@ class TPUEmbedding(object):
   def _validate_generate_enqueue_ops_enqueue_datas_list(self,
                                                         enqueue_datas_list):
     """Validate `enqueue_datas_list`."""
-    feature_set = set(self._feature_to_table_dict.keys())
+    feature_set = set(self._feature_to_config_dict.keys())
     contiguous_device = None
     for i, enqueue_datas in enumerate(enqueue_datas_list):
       used_feature_set = set(enqueue_datas.keys())
@@ -674,7 +705,7 @@ class TPUEmbedding(object):
       device_feature = None
       for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
-            self._feature_to_table_dict[feature]].combiner
+            self._feature_to_config_dict[feature].table_id].combiner
         if not isinstance(enqueue_data, EnqueueData):
           raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
                            'not mapped to `EnqueueData`. `feature`: {}'.format(
@@ -726,7 +757,7 @@ class TPUEmbedding(object):
       enqueue_data0 = list(enqueue_datas.values())[0]
       with ops.colocate_with(enqueue_data0.embedding_indices):
         (sample_indices_list, embedding_indices_list, aggregation_weights_list,
-         table_ids) = (
+         table_ids, max_sequence_lengths) = (
             self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
         return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
             sample_indices_list,
@@ -734,7 +765,8 @@ class TPUEmbedding(object):
             aggregation_weights_list,
             table_ids,
             device_ordinal=device_ordinal,
-            combiners=self._combiners)
+            combiners=self._combiners,
+            max_sequence_lengths=max_sequence_lengths)
 
   def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
@@ -748,7 +780,7 @@ class TPUEmbedding(object):
     """
 
     (sample_indices_list, embedding_indices_list, aggregation_weights_list,
-     table_ids) = [], [], [], []
+     table_ids, max_sequence_lengths) = [], [], [], [], []
     for table_id, table in enumerate(self._table_to_features_dict):
       features = self._table_to_features_dict[table]
       for feature in features:
@@ -769,9 +801,11 @@ class TPUEmbedding(object):
           embedding_indices_list.append(enqueue_data.embedding_indices)
 
         table_ids.append(table_id)
+        max_sequence_lengths.append(
+            self._feature_to_config_dict[feature].max_sequence_length)
 
     return (sample_indices_list, embedding_indices_list,
-            aggregation_weights_list, table_ids)
+            aggregation_weights_list, table_ids, max_sequence_lengths)
 
   def get_activations(self):
     """Get activations for features.
@@ -790,9 +824,21 @@ class TPUEmbedding(object):
     activations = collections.OrderedDict()
     for table_id, table in enumerate(self._table_to_features_dict):
       features = self._table_to_features_dict[table]
-      for lookup_id, feature in enumerate(features):
-        stride = len(self._table_to_features_dict[table])
-        activations[feature] = recv_activations[table_id][lookup_id::stride, :]
+      num_features = self._table_to_num_features_dict[table]
+      feature_index = 0
+      table_activations = array_ops.reshape(
+          recv_activations[table_id],
+          [self.batch_size_per_core, num_features, -1])
+      for feature in features:
+        seq_length = self._feature_to_config_dict[feature].max_sequence_length
+        if not seq_length:
+          activations[feature] = table_activations[:, feature_index, :]
+          feature_index = feature_index + 1
+        else:
+          activations[feature] = (
+              table_activations[:, feature_index:(feature_index+seq_length), :])
+          feature_index = feature_index + seq_length
+
     return activations
 
   def generate_send_gradients_op(self, feature_to_gradient_dict):
@@ -815,12 +861,16 @@ class TPUEmbedding(object):
     gradients = []
     for table in self._table_to_features_dict:
       features = self._table_to_features_dict[table]
-      table_gradients = [
-          feature_to_gradient_dict[feature] for feature in features
-      ]
+      table_gradients = []
+      for feature in features:
+        gradient = feature_to_gradient_dict[feature]
+        # Expand dims for non-sequence feature to match sequence features.
+        if gradient.shape.ndims == 2:
+          gradient = array_ops.expand_dims(gradient, 1)
+        table_gradients.append(gradient)
       interleaved_table_grads = array_ops.reshape(
-          array_ops.stack(table_gradients, axis=1),
-          [-1, table_gradients[0].shape[1]])
+          array_ops.concat(table_gradients, axis=1),
+          [-1, table_gradients[0].shape[-1]])
       gradients.append(interleaved_table_grads)
     return tpu_ops.send_tpu_embedding_gradients(
         inputs=gradients, config=self.config_proto.SerializeToString())
@@ -834,21 +884,22 @@ def _validate_table_to_config_dict(table_to_config_dict):
                        '`TableConfig`, got {} for {}.'.format(type(v), k))
 
 
-def _validate_feature_to_table_dict(table_to_config_dict,
-                                    feature_to_table_dict):
-  """Validate `feature_to_table_dict`."""
-  used_table_set = set(feature_to_table_dict.values())
+def _validate_feature_to_config_dict(table_to_config_dict,
+                                     feature_to_config_dict):
+  """Validate `feature_to_config_dict`."""
+  used_table_set = set([feature.table_id
+                        for feature in feature_to_config_dict.values()])
   table_set = set(table_to_config_dict.keys())
 
   unused_table_set = table_set - used_table_set
   if unused_table_set:
     raise ValueError('`table_to_config_dict` specifies table that is not '
-                     'used in `feature_to_table_dict`: {}.'
+                     'used in `feature_to_config_dict`: {}.'
                      .format(unused_table_set))
 
   extra_table_set = used_table_set - table_set
   if extra_table_set:
-    raise ValueError('`feature_to_table_dict` refers to a table that is not '
+    raise ValueError('`feature_to_config_dict` refers to a table that is not '
                      'specified in `table_to_config_dict`: {}.'
                      .format(extra_table_set))
@@ -1135,19 +1186,30 @@ def _create_combiners(table_to_config_dict, table_to_features_dict):
   return combiners
 
 
-def _create_table_to_features_dict(feature_to_table_dict):
+def _create_table_to_features_and_num_features_dicts(feature_to_config_dict):
   """Create mapping from table to a list of its features."""
   table_to_features_dict_tmp = {}
-  for feature, table in six.iteritems(feature_to_table_dict):
-    if table in table_to_features_dict_tmp:
-      table_to_features_dict_tmp[table].append(feature)
+  table_to_num_features_dict_tmp = {}
+  for feature, feature_config in six.iteritems(feature_to_config_dict):
+    if feature_config.table_id in table_to_features_dict_tmp:
+      table_to_features_dict_tmp[feature_config.table_id].append(feature)
     else:
-      table_to_features_dict_tmp[table] = [feature]
+      table_to_features_dict_tmp[feature_config.table_id] = [feature]
+      table_to_num_features_dict_tmp[feature_config.table_id] = 0
+    if feature_config.max_sequence_length == 0:
+      table_to_num_features_dict_tmp[feature_config.table_id] = (
+          table_to_num_features_dict_tmp[feature_config.table_id] + 1)
+    else:
+      table_to_num_features_dict_tmp[feature_config.table_id] = (
+          table_to_num_features_dict_tmp[feature_config.table_id] +
+          feature_config.max_sequence_length)
 
   table_to_features_dict = collections.OrderedDict()
+  table_to_num_features_dict = collections.OrderedDict()
   for table in sorted(table_to_features_dict_tmp):
     table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
-  return table_to_features_dict
+    table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table]
+  return table_to_features_dict, table_to_num_features_dict
 
 
 def _create_device_fn(hosts):
diff --git a/tensorflow/python/tpu/tpu_embedding_gradient.py b/tensorflow/python/tpu/tpu_embedding_gradient.py
index 7437756601e..657fa1068ca 100644
--- a/tensorflow/python/tpu/tpu_embedding_gradient.py
+++ b/tensorflow/python/tpu/tpu_embedding_gradient.py
@@ -115,7 +115,7 @@ def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
   """
   new_activations = collections.OrderedDict()
   for feature in activations:
-    table = tpu_embedding.feature_to_table_dict[feature]
+    table = tpu_embedding.feature_to_config_dict[feature].table_id
     new_activations[feature] = tpu_ops.tpu_embedding_activations(
         dummy_table_variables[table],
         activations[feature],