From b220af894a67d875b8059cfa78f0cb28fa72f834 Mon Sep 17 00:00:00 2001 From: Revan Sopher <rsopher@google.com> Date: Wed, 1 Apr 2020 13:23:59 -0700 Subject: [PATCH] Add EnqueueTPUEmbeddingRaggedTensorBatch for RaggedTensor support. PiperOrigin-RevId: 304250071 Change-Id: If1f0d7a8716c95a090f28d085a46ffa9c3e9053e --- ...EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt | 77 +++++++ tensorflow/core/ops/ops.pbtxt | 95 ++++++++ tensorflow/core/ops/tpu_embedding_ops.cc | 16 ++ tensorflow/python/tpu/ops/tpu_ops.py | 73 +++++++ tensorflow/python/tpu/tpu_embedding.py | 202 +++++++++++++++--- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 7 files changed, 441 insertions(+), 30 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt new file mode 100644 index 00000000000..cdcdd6d06b0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt @@ -0,0 +1,77 @@ +op { + graph_op_name: "EnqueueTPUEmbeddingRaggedTensorBatch" + visibility: HIDDEN + in_arg { + name: "sample_splits" + description: <<END +A list of rank 1 Tensors specifying the break points for splitting +embedding_indices and aggregation_weights into rows. +It corresponds to ids.row_splits in embedding_lookup(), when ids is a +RaggedTensor. +END + } + in_arg { + name: "embedding_indices" + description: <<END +A list of rank 1 Tensors, indices into the embedding tables. +It corresponds to ids.values in embedding_lookup(), when ids is a RaggedTensor. +END + } + in_arg { + name: "aggregation_weights" + description: <<END +A list of rank 1 Tensors containing per training example +aggregation weights. 
It corresponds to the values field of a RaggedTensor +with the same row_splits as ids in embedding_lookup(), when ids is a +RaggedTensor. +END + } + in_arg { + name: "mode_override" + description: <<END +A string input that overrides the mode specified in the +TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +in TPUEmbeddingConfiguration is used, otherwise mode_override is used. +END + } + attr { + name: "device_ordinal" + description: <<END +The TPU device to use. Should be >= 0 and less than the number +of TPU cores in the task on which the node is placed. +END + } + attr { + name: "combiners" + description: <<END +A list of string scalars, one for each embedding table that specify +how to normalize the embedding activations after weighted summation. +Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have +the sum of the weights be 0 for 'mean' or the sum of the squared weights be +0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for +all tables. +END + } + attr { + name: "table_ids" + description: <<END +A list of integers specifying the identifier of the embedding table +(offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the +corresponding input. The ith input is looked up using table_ids[i]. The size +of the table_ids list must be equal to that of sample_splits, +embedding_indices and aggregation_weights. +END + } + summary: "Eases the porting of code that uses tf.nn.embedding_lookup()." + description: <<END +sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond +to the ith feature. table_ids[i] indicates which embedding table to look up ith +feature. + +The tensors at corresponding positions in two of the input lists, +embedding_indices and aggregation_weights, must have the same shape, i.e.
rank 1 +with dim_size() equal to the total number of lookups into the table described by +the corresponding feature. +END +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index c4d9daeeb47..91750889b95 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13184,6 +13184,101 @@ op { } is_stateful: true } +op { + name: "EnqueueTPUEmbeddingRaggedTensorBatch" + input_arg { + name: "sample_splits" + type_attr: "T1" + number_attr: "N" + } + input_arg { + name: "embedding_indices" + type_attr: "T2" + number_attr: "N" + } + input_arg { + name: "aggregation_weights" + type_attr: "T3" + number_attr: "N" + } + input_arg { + name: "mode_override" + type: DT_STRING + } + attr { + name: "T1" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T2" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T3" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "device_ordinal" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "combiners" + type: "list(string)" + default_value { + list { + } + } + } + attr { + name: "table_ids" + type: "list(int)" + } + attr { + name: "max_sequence_lengths" + type: "list(int)" + default_value { + list { + } + } + } + is_stateful: true +} op { name: "EnsureShape" input_arg { diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc index 821dff7c64a..164d78e8e9e 100644 --- a/tensorflow/core/ops/tpu_embedding_ops.cc +++ b/tensorflow/core/ops/tpu_embedding_ops.cc @@ -168,4 +168,20 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch") + .Input("sample_splits: N * T1") + .Input("embedding_indices: N * T2") + .Input("aggregation_weights: N * T3") + .Input("mode_override: string") + .Attr("T1: {int32,int64} = DT_INT32") + .Attr("T2: {int32,int64} = DT_INT32") + .Attr("T3: {float32,float64} = DT_FLOAT") + .Attr("N: int >= 1") + .Attr("device_ordinal: int = -1") + .Attr("combiners: list(string) = []") + .Attr("table_ids: list(int)") + .Attr("max_sequence_lengths: list(int) = []") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape); + } // namespace tensorflow diff --git a/tensorflow/python/tpu/ops/tpu_ops.py b/tensorflow/python/tpu/ops/tpu_ops.py index c1ea3641757..8facb1fdad7 100644 --- a/tensorflow/python/tpu/ops/tpu_ops.py +++ b/tensorflow/python/tpu/ops/tpu_ops.py @@ -444,3 +444,76 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__) + + +# pylint: disable=protected-access +def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits, + embedding_indices, + aggregation_weights, + table_ids, + device_ordinal, + max_sequence_lengths=None, + combiners=None, + mode_override=None, + name=None): + """A placeholder op for enqueueing embedding IDs to the TPU. + + Args: + sample_splits: A list of rank 1 Tensors specifying the break points for + splitting embedding_indices and aggregation_weights into rows. It + corresponds to ids.row_splits in embedding_lookup(), when ids is a + RaggedTensor. Both int32 and int64 are allowed and will be converted to + int32 internally. + embedding_indices: A list of rank 1 Tensors, indices into the embedding + tables. It corresponds to ids.values in embedding_lookup(), when ids is a + RaggedTensor. Both int32 and int64 are allowed and will be converted to + int32 internally. + aggregation_weights: A list of rank 1 Tensors containing per training + example aggregation weights. 
It corresponds to the values field of a + RaggedTensor with the same row_splits as ids in embedding_lookup(), when + ids is a RaggedTensor. Both float32 and float64 are allowed and will be + converted to float32 internally. + table_ids: A list of integers specifying the identifier of the embedding + table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to + lookup the corresponding input. The ith input is looked up using + table_ids[i]. The size of the table_ids list must be equal to that of + sample_splits, embedding_indices and aggregation_weights. + device_ordinal: The TPU device to use. Should be >= 0 and less than the + number of TPU cores in the task on which the node is placed. + max_sequence_lengths: A list of integers, the size of which is equal to + sample_splits. If equal to 0, the corresponding feature is considered to + be a non-sequence feature. If greater than 0, the corresponding feature is + a sequence feature with the given maximal length. If None, then we assume + a list of all zeroes. + combiners: A list of string scalars, one for each embedding table that + specify how to normalize the embedding activations after weighted + summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is + invalid to have the sum of the weights be 0 for 'mean' or the sum of the + squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default + is to use 'sum' for all tables (optional). + mode_override: A string input that overrides the mode specified in the + TPUEmbeddingConfiguration. Supported values are {'unspecified', + 'inference', 'training', 'backward_pass_only'}. When set to 'unspecified', + the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override + is used (optional). + name: A name for the operation (optional). + + Returns: + An EnqueueTPUEmbeddingRaggedTensorBatch operation.
+ """ + if mode_override is None: + mode_override = "unspecified" + return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( + sample_splits=sample_splits, + embedding_indices=embedding_indices, + aggregation_weights=aggregation_weights, + table_ids=table_ids, + device_ordinal=device_ordinal, + max_sequence_lengths=max_sequence_lengths, + combiners=combiners, + mode_override=mode_override, + name=name) + + +enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = ( + gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__) diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index e3dbe7fb93f..e24188eaf16 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -205,6 +205,48 @@ class EnqueueData( aggregation_weights=weights.values if weights is not None else None) +class RaggedEnqueueData( + collections.namedtuple( + 'RaggedEnqueueData', + ['embedding_indices', 'sample_splits', 'aggregation_weights'])): + """RaggedTensor Data to be enqueued through generate_enqueue_ops().""" + + def __new__(cls, + embedding_indices, + sample_splits=None, + aggregation_weights=None): + """Data to be enqueued through generate_enqueue_ops(). + + Args: + embedding_indices: A rank 1 Tensor, indices into the embedding tables. It + corresponds to ids.values in embedding_lookup(), when ids is a + RaggedTensor. Both int32 and int64 are allowed and will be converted to + int32 internally. + sample_splits: A rank 1 Tensor specifying the break points for splitting + embedding_indices and aggregation_weights into rows. It corresponds to + ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both + int32 and int64 are allowed and will be converted to int32 internally. + aggregation_weights: A rank 1 Tensor containing per training example + aggregation weights. 
It corresponds to the values field of a + RaggedTensor with the same row_splits as ids in embedding_lookup(), when + ids is a RaggedTensor. + + Returns: + An RaggedEnqueueData tuple. + + """ + return super(RaggedEnqueueData, + cls).__new__(cls, embedding_indices, sample_splits, + aggregation_weights) + + @staticmethod + def from_ragged_tensor(rg_tensor, weights=None): + return RaggedEnqueueData( + rg_tensor.values, + rg_tensor.row_splits, + aggregation_weights=weights.values if weights is not None else None) + + def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list): """Convenient function for generate_enqueue_ops(). @@ -229,6 +271,30 @@ def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list): return enqueue_datas_list +def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list): + """Convenient function for generate_enqueue_ops(). + + Args: + rg_tensors_list: a list of dictionary mapping from string of feature names + to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the + same host should be contiguous on the list. + + Returns: + enqueue_datas_list: a list of dictionary mapping from string + of feature names to RaggedEnqueueData. Each dictionary is for one + TPU core. Dictionaries for the same host should be contiguous + on the list. + + """ + enqueue_datas_list = [] + for rg_tensors in rg_tensors_list: + enqueue_datas = collections.OrderedDict( + (k, RaggedEnqueueData.from_ragged_tensor(v)) + for k, v in six.iteritems(rg_tensors)) + enqueue_datas_list.append(enqueue_datas) + return enqueue_datas_list + + AdamSlotVariableNames = collections.namedtuple( 'AdamSlotVariableNames', ['m', 'v']) @@ -1159,7 +1225,12 @@ class TPUEmbedding(object): slot_variables_by_table, load_ops, retrieve_ops) - def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None): + def generate_enqueue_ops( + self, + enqueue_datas_list, + mode_override=None, + ragged=False, + ): """Generate enqueue ops. 
Args: @@ -1172,6 +1243,8 @@ class TPUEmbedding(object): 'inference', 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override is used (optional). + ragged: If True, creates RaggedTensor enqueue ops rather than + SparseTensor. Returns: Ops to enqueue to TPU for embedding. @@ -1182,6 +1255,7 @@ class TPUEmbedding(object): enqueue_datas, device_ordinal=i % self._num_cores_per_host, mode_override=mode_override, + ragged=ragged, ) for i, enqueue_datas in enumerate(enqueue_datas_list) ] @@ -1211,28 +1285,50 @@ class TPUEmbedding(object): for feature, enqueue_data in six.iteritems(enqueue_datas): combiner = self._table_to_config_dict[ self._feature_to_config_dict[feature].table_id].combiner - if not isinstance(enqueue_data, EnqueueData): - raise ValueError('`enqueue_datas_list[{}]` has a feature that is ' - 'not mapped to `EnqueueData`. `feature`: {}'.format( - i, feature)) - if enqueue_data.sample_indices is None and combiner: - logging.warn('No sample indices set for features %f table %f but ' - 'combiner is set to %s.', feature, - self._feature_to_config_dict[feature].table_id, combiner) + if isinstance(enqueue_data, EnqueueData): + if enqueue_data.sample_indices is None and combiner: + logging.warn( + 'No sample indices set for features %f table %f but ' + 'combiner is set to %s.', feature, + self._feature_to_config_dict[feature].table_id, combiner) + if (enqueue_data.sample_indices is not None and + enqueue_data.sample_indices.device != + enqueue_data.embedding_indices.device): + raise ValueError( + 'Device of sample_indices does not agree with ' + 'that of embedding_indices for feature {}.'.format(feature)) + if (enqueue_data.aggregation_weights is not None and + enqueue_data.aggregation_weights.device != + enqueue_data.embedding_indices.device): + raise ValueError( + 'Device of aggregation_weights does not agree with ' + 'that of embedding_indices for feature {}.'.format(feature)) - if 
(enqueue_data.sample_indices is not None and - enqueue_data.sample_indices.device != - enqueue_data.embedding_indices.device): + elif isinstance(enqueue_data, RaggedEnqueueData): + if enqueue_data.sample_splits is None and combiner: + logging.warn( + 'No sample splits set for features %s table %s but ' + 'combiner is set to %s.', feature, + self._feature_to_config_dict[feature].table_id, combiner) + if (enqueue_data.sample_splits is not None and + enqueue_data.sample_splits.device != + enqueue_data.embedding_indices.device): + raise ValueError( + 'Device of sample_splits does not agree with ' + 'that of embedding_indices for feature {}.'.format(feature)) + if (enqueue_data.aggregation_weights is not None and + enqueue_data.aggregation_weights.device != + enqueue_data.embedding_indices.device): + raise ValueError( + 'Device of aggregation_weights does not agree with ' + 'that of embedding_indices for feature {}.'.format(feature)) + + else: raise ValueError( - 'Device of sample_indices does not agree with ' - 'that of embedding_indices for feature {}.'.format(feature)) - if (enqueue_data.aggregation_weights is not None and - enqueue_data.aggregation_weights.device != - enqueue_data.embedding_indices.device): - raise ValueError( - 'Device of aggregation_weights does not agree with ' - 'that of embedding_indices for feature {}.'.format(feature)) + '`enqueue_datas_list[{}]` has a feature that is not mapped to ' + '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format( + i, feature)) # Check all features are on the same device.
if device is None: device = enqueue_data.embedding_indices.device @@ -1257,23 +1353,69 @@ class TPUEmbedding(object): else: contiguous_device = device - def _generate_enqueue_op( - self, enqueue_datas, device_ordinal, mode_override=None): + def _generate_enqueue_op(self, + enqueue_datas, + device_ordinal, + mode_override=None, + ragged=False): + """Creates op for enqueuing batch to TPU.""" enqueue_data0 = list(enqueue_datas.values())[0] with ops.colocate_with(enqueue_data0.embedding_indices): - return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( - device_ordinal=device_ordinal, - combiners=self._combiners, - mode_override=mode_override, - **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas) - ) + if ragged: + # note that this is currently identical in behavior + return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( + device_ordinal=device_ordinal, + combiners=self._combiners, + mode_override=mode_override, + **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas)) + else: + return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( + device_ordinal=device_ordinal, + combiners=self._combiners, + mode_override=mode_override, + **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)) + + def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas): + """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`. + + Args: + enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding. + + Returns: + Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`. 
+ """ + + kwargs = { + 'sample_splits': [], + 'embedding_indices': [], + 'aggregation_weights': [], + 'table_ids': [], + 'max_sequence_lengths': [], + } + for table_id, table in enumerate(self._table_to_features_dict): + features = self._table_to_features_dict[table] + for feature in features: + enqueue_data = enqueue_datas[feature] + + kwargs['sample_splits'].append(enqueue_data.sample_splits) + + kwargs['aggregation_weights'].append( + enqueue_data.aggregation_weights if enqueue_data.aggregation_weights + is not None else array_ops.zeros((0,), dtype=dtypes.float32)) + + kwargs['embedding_indices'].append(enqueue_data.embedding_indices) + + kwargs['table_ids'].append(table_id) + kwargs['max_sequence_lengths'].append( + self._feature_to_config_dict[feature].max_sequence_length) + + return kwargs def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas): """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`. Args: - enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or - dense. + enqueue_datas: a `Dict` of `EnqueueData` objects for embedding. Returns: Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`. 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 80aca6304c0..af2a47fb3b9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1244,6 +1244,10 @@ tf_module { name: "EnqueueTPUEmbeddingIntegerBatch" argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " } + member_method { + name: "EnqueueTPUEmbeddingRaggedTensorBatch" + argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], " + } member_method { name: "EnqueueTPUEmbeddingSparseBatch" argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 80aca6304c0..af2a47fb3b9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1244,6 +1244,10 @@ tf_module { name: "EnqueueTPUEmbeddingIntegerBatch" argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " } + member_method { + name: "EnqueueTPUEmbeddingRaggedTensorBatch" + argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], " + } member_method { name: 
"EnqueueTPUEmbeddingSparseBatch" argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "