Add EnqueueTPUEmbeddingRaggedTensorBatch for RaggedTensor support.

PiperOrigin-RevId: 304250071
Change-Id: If1f0d7a8716c95a090f28d085a46ffa9c3e9053e
This commit is contained in:
Revan Sopher 2020-04-01 13:23:59 -07:00 committed by TensorFlower Gardener
parent b2ce99d186
commit b220af894a
7 changed files with 441 additions and 30 deletions

View File

@ -0,0 +1,77 @@
op {
graph_op_name: "EnqueueTPUEmbeddingRaggedTensorBatch"
visibility: HIDDEN
in_arg {
name: "sample_splits"
description: <<END
A list of rank 1 Tensors specifying the break points for splitting
embedding_indices and aggregation_weights into rows.
It corresponds to ids.row_splits in embedding_lookup(), when ids is a
RaggedTensor.
END
}
in_arg {
name: "embedding_indices"
description: <<END
A list of rank 1 Tensors, indices into the embedding tables.
It corresponds to ids.values in embedding_lookup(), when ids is a RaggedTensor.
END
}
in_arg {
name: "aggregation_weights"
description: <<END
A list of rank 1 Tensors containing per training example
aggregation weights. It corresponds to the values field of a RaggedTensor
with the same row_splits as ids in embedding_lookup(), when ids is a
RaggedTensor.
END
}
in_arg {
name: "mode_override"
description: <<END
A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
END
}
attr {
name: "device_ordinal"
description: <<END
The TPU device to use. Should be >= 0 and less than the number
of TPU cores in the task on which the node is placed.
END
}
attr {
name: "combiners"
description: <<END
A list of string scalars, one for each embedding table that specify
how to normalize the embedding activations after weighted summation.
Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
the sum of the weights be 0 for 'mean' or the sum of the squared weights be
0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
all tables.
END
}
attr {
name: "table_ids"
description: <<END
A list of integers specifying the identifier of the embedding table
(offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the
corresponding input. The ith input is looked up using table_ids[i]. The size
of the table_ids list must be equal to that of sample_indices,
embedding_indices and aggregation_weights.
END
}
summary: "Eases the porting of code that uses tf.nn.embedding_lookup()."
description: <<END
sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
to the ith feature. table_ids[i] indicates which embedding table to look up ith
feature.
The tensors at corresponding positions in two of the input lists,
embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
with dim_size() equal to the total number of lookups into the table described by
the corresponding feature.
END
}

View File

@ -13184,6 +13184,101 @@ op {
}
is_stateful: true
}
op {
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
input_arg {
name: "sample_indices"
type_attr: "T1"
number_attr: "N"
}
input_arg {
name: "embedding_indices"
type_attr: "T2"
number_attr: "N"
}
input_arg {
name: "aggregation_weights"
type_attr: "T3"
number_attr: "N"
}
input_arg {
name: "mode_override"
type: DT_STRING
}
attr {
name: "T1"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T2"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T3"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
attr {
name: "device_ordinal"
type: "int"
default_value {
i: -1
}
}
attr {
name: "combiners"
type: "list(string)"
default_value {
list {
}
}
}
attr {
name: "table_ids"
type: "list(int)"
}
attr {
name: "max_sequence_lengths"
type: "list(int)"
default_value {
list {
}
}
}
is_stateful: true
}
op {
name: "EnsureShape"
input_arg {

View File

@ -168,4 +168,20 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
.Input("sample_splits: N * T1")
.Input("embedding_indices: N * T2")
.Input("aggregation_weights: N * T3")
.Input("mode_override: string")
.Attr("T1: {int32,int64} = DT_INT32")
.Attr("T2: {int32,int64} = DT_INT32")
.Attr("T3: {float32,float64} = DT_FLOAT")
.Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")
.Attr("table_ids: list(int)")
.Attr("max_sequence_lengths: list(int) = []")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow

View File

@ -444,3 +444,76 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
embedding_indices,
aggregation_weights,
table_ids,
device_ordinal,
max_sequence_lengths=None,
combiners=None,
mode_override=None,
name=None):
"""A placeholder op for enqueueing embedding IDs to the TPU.
Args:
sample_splits: A list of rank 1 Tensors specifying the break points for
splitting embedding_indices and aggregation_weights into rows. It
corresponds to ids.row_splits in embedding_lookup(), when ids is a
RaggedTensor. Both int32 and int64 are allowed and will be converted to
int32 internally.
embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. It corresponds to ids.values in embedding_lookup(), when ids is a
RaggedTensor. Both int32 and int64 are allowed and will be converted to
int32 internally.
aggregation_weights: A list of rank 1 Tensors containing per training
example aggregation weights. It corresponds to the values field of a
RaggedTensor with the same row_splits as ids in embedding_lookup(), when
ids is a RaggedTensor. Both float32 and float64 are allowed and will be
converted to float32 internally.
table_ids: A list of integers specifying the identifier of the embedding
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
lookup the corresponding input. The ith input is looked up using
table_ids[i]. The size of the table_ids list must be equal to that of
sample_indices, embedding_indices and aggregation_weights.
device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed.
max_sequence_lengths: A list of integers, the size of which is equal to
sample_indices. If equal to 0, the corresponding feature is considered to
be a non-sequence feature, If greater than 0, the corresponding feature is
a sequence feature with the given maximal length. If None, then we assume
a list of all zeroes.
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
invalid to have the sum of the weights be 0 for 'mean' or the sum of the
squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
is to use 'sum' for all tables (optional).
mode_override: A string input that overrides the mode specified in the
TPUEmbeddingConfiguration. Supported values are {'unspecified',
'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
is used (optional).
name: A name for the operation (optional).
Returns:
An EnqueueTPUEmbeddingRaggedTensorBatch operation.
"""
if mode_override is None:
mode_override = "unspecified"
return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
sample_splits=sample_splits,
embedding_indices=embedding_indices,
aggregation_weights=aggregation_weights,
table_ids=table_ids,
device_ordinal=device_ordinal,
max_sequence_lengths=max_sequence_lengths,
combiners=combiners,
mode_override=mode_override,
name=name)
enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)

View File

@ -205,6 +205,48 @@ class EnqueueData(
aggregation_weights=weights.values if weights is not None else None)
class RaggedEnqueueData(
collections.namedtuple(
'RaggedEnqueueData',
['embedding_indices', 'sample_splits', 'aggregation_weights'])):
"""RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
def __new__(cls,
embedding_indices,
sample_splits=None,
aggregation_weights=None):
"""Data to be enqueued through generate_enqueue_ops().
Args:
embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
corresponds to ids.values in embedding_lookup(), when ids is a
RaggedTensor. Both int32 and int64 are allowed and will be converted to
int32 internally.
sample_splits: A rank 1 Tensor specifying the break points for splitting
embedding_indices and aggregation_weights into rows. It corresponds to
ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
int32 and int64 are allowed and will be converted to int32 internally.
aggregation_weights: A rank 1 Tensor containing per training example
aggregation weights. It corresponds to the values field of a
RaggedTensor with the same row_splits as ids in embedding_lookup(), when
ids is a RaggedTensor.
Returns:
An RaggedEnqueueData tuple.
"""
return super(RaggedEnqueueData,
cls).__new__(cls, embedding_indices, sample_splits,
aggregation_weights)
@staticmethod
def from_ragged_tensor(rg_tensor, weights=None):
return RaggedEnqueueData(
rg_tensor.values,
rg_tensor.row_splits,
aggregation_weights=weights.values if weights is not None else None)
def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
"""Convenient function for generate_enqueue_ops().
@ -229,6 +271,30 @@ def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
return enqueue_datas_list
def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
"""Convenient function for generate_enqueue_ops().
Args:
rg_tensors_list: a list of dictionary mapping from string of feature names
to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
same host should be contiguous on the list.
Returns:
enqueue_datas_list: a list of dictionary mapping from string
of feature names to RaggedEnqueueData. Each dictionary is for one
TPU core. Dictionaries for the same host should be contiguous
on the list.
"""
enqueue_datas_list = []
for rg_tensors in rg_tensors_list:
enqueue_datas = collections.OrderedDict(
(k, RaggedEnqueueData.from_ragged_tensor(v))
for k, v in six.iteritems(rg_tensors))
enqueue_datas_list.append(enqueue_datas)
return enqueue_datas_list
AdamSlotVariableNames = collections.namedtuple(
'AdamSlotVariableNames', ['m', 'v'])
@ -1159,7 +1225,12 @@ class TPUEmbedding(object):
slot_variables_by_table,
load_ops, retrieve_ops)
def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
def generate_enqueue_ops(
self,
enqueue_datas_list,
mode_override=None,
ragged=False,
):
"""Generate enqueue ops.
Args:
@ -1172,6 +1243,8 @@ class TPUEmbedding(object):
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
ragged: If True, creates RaggedTensor enqueue ops rather than
SparseTensor.
Returns:
Ops to enqueue to TPU for embedding.
@ -1182,6 +1255,7 @@ class TPUEmbedding(object):
enqueue_datas,
device_ordinal=i % self._num_cores_per_host,
mode_override=mode_override,
ragged=ragged,
) for i, enqueue_datas in enumerate(enqueue_datas_list)
]
@ -1211,28 +1285,50 @@ class TPUEmbedding(object):
for feature, enqueue_data in six.iteritems(enqueue_datas):
combiner = self._table_to_config_dict[
self._feature_to_config_dict[feature].table_id].combiner
if not isinstance(enqueue_data, EnqueueData):
raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
'not mapped to `EnqueueData`. `feature`: {}'.format(
i, feature))
if enqueue_data.sample_indices is None and combiner:
logging.warn('No sample indices set for features %f table %f but '
'combiner is set to %s.', feature,
self._feature_to_config_dict[feature].table_id, combiner)
if isinstance(enqueue_data, EnqueueData):
if enqueue_data.sample_indices is None and combiner:
logging.warn(
'No sample indices set for features %f table %f but '
'combiner is set to %s.', feature,
self._feature_to_config_dict[feature].table_id, combiner)
if (enqueue_data.sample_indices is not None and
enqueue_data.sample_indices.device !=
enqueue_data.embedding_indices.device):
raise ValueError(
'Device of sample_indices does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
if (enqueue_data.aggregation_weights is not None and
enqueue_data.aggregation_weights.device !=
enqueue_data.embedding_indices.device):
raise ValueError(
'Device of aggregation_weights does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
if (enqueue_data.sample_indices is not None and
enqueue_data.sample_indices.device !=
enqueue_data.embedding_indices.device):
elif isinstance(enqueue_data, RaggedEnqueueData):
if enqueue_data.sample_splits is None and combiner:
logging.warn(
'No sample splits set for features %f table %f but '
'combiner is set to %s.', feature,
self._feature_to_config_dict[feature].table_id, combiner)
if (enqueue_data.sample_splits is not None and
enqueue_data.sample_splits.device !=
enqueue_data.embedding_indices.device):
raise ValueError(
'Device of sample_splits does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
if (enqueue_data.aggregation_weights is not None and
enqueue_data.aggregation_weights.device !=
enqueue_data.embedding_indices.device):
raise ValueError(
'Device of aggregation_weights does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
else:
raise ValueError(
'Device of sample_indices does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
if (enqueue_data.aggregation_weights is not None and
enqueue_data.aggregation_weights.device !=
enqueue_data.embedding_indices.device):
raise ValueError(
'Device of aggregation_weights does not agree with '
'that of embedding_indices for feature {}.'.format(feature))
'`enqueue_datas_list[{}]` has a feature that is not mapped to '
'`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
i, feature))
# Check all features are on the same device.
if device is None:
device = enqueue_data.embedding_indices.device
@ -1257,23 +1353,69 @@ class TPUEmbedding(object):
else:
contiguous_device = device
def _generate_enqueue_op(
self, enqueue_datas, device_ordinal, mode_override=None):
def _generate_enqueue_op(self,
enqueue_datas,
device_ordinal,
mode_override=None,
ragged=False):
"""Creates op for enqueuing batch to TPU."""
enqueue_data0 = list(enqueue_datas.values())[0]
with ops.colocate_with(enqueue_data0.embedding_indices):
return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
device_ordinal=device_ordinal,
combiners=self._combiners,
mode_override=mode_override,
**self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
)
if ragged:
# note that this is currently identical in behavior
return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
device_ordinal=device_ordinal,
combiners=self._combiners,
mode_override=mode_override,
**self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas))
else:
return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
device_ordinal=device_ordinal,
combiners=self._combiners,
mode_override=mode_override,
**self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas):
"""Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`.
Args:
enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
Returns:
Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`.
"""
kwargs = {
'sample_splits': [],
'embedding_indices': [],
'aggregation_weights': [],
'table_ids': [],
'max_sequence_lengths': [],
}
for table_id, table in enumerate(self._table_to_features_dict):
features = self._table_to_features_dict[table]
for feature in features:
enqueue_data = enqueue_datas[feature]
kwargs['sample_splits'].append(enqueue_data.sample_splits)
kwargs['aggregation_weights'].append(
enqueue_data.aggregation_weights if enqueue_data.aggregation_weights
is not None else array_ops.zeros((0,), dtype=dtypes.float32))
kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
kwargs['table_ids'].append(table_id)
kwargs['max_sequence_lengths'].append(
self._feature_to_config_dict[feature].max_sequence_length)
return kwargs
def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
"""Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
Args:
enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
dense.
enqueue_datas: a `Dict` of `EnqueueData` objects for embedding.
Returns:
Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.

View File

@ -1244,6 +1244,10 @@ tf_module {
name: "EnqueueTPUEmbeddingIntegerBatch"
argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
}
member_method {
name: "EnqueueTPUEmbeddingSparseBatch"
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "

View File

@ -1244,6 +1244,10 @@ tf_module {
name: "EnqueueTPUEmbeddingIntegerBatch"
argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
}
member_method {
name: "EnqueueTPUEmbeddingSparseBatch"
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "