Add option for input tensors for TPUEmbedding to have a first dimension which is a multiple of the batch_size.

PiperOrigin-RevId: 354344093
Change-Id: I0fbb6820b9fa0bb0128ed4a27e2f527937741cbf
This commit is contained in:
Bruce Fontaine 2021-01-28 10:43:09 -08:00 committed by TensorFlower Gardener
parent 8a356e8ca5
commit e994fb8f32
4 changed files with 20 additions and 4 deletions

View File

@ -164,6 +164,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
.Attr("combiners: list(string) = []")
.Attr("table_ids: list(int)")
.Attr("max_sequence_lengths: list(int) = []")
.Attr("num_features: list(int) = []")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
@ -180,6 +181,7 @@ REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
.Attr("combiners: list(string) = []")
.Attr("table_ids: list(int)")
.Attr("max_sequence_lengths: list(int) = []")
.Attr("num_features: list(int) = []")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);

View File

@ -380,6 +380,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
table_ids,
device_ordinal,
max_sequence_lengths=None,
num_features=None,
combiners=None,
mode_override=None,
name=None):
@ -412,6 +413,11 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
be a non-sequence feature. If greater than 0, the corresponding feature is
a sequence feature with the given maximal length. If None, then we assume
a list of all zeroes.
num_features: A list of integers, the length of which must equal the
length of sample_indices. If non-empty, entries in this list must be at
least 1. For each batch element, we will take num_features rows of the
input tensor for embedding lookup. E.g., when sample_indices is empty,
the embedding indices must be of shape (batch_size*num_features).
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
@ -439,6 +445,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
max_sequence_lengths=max_sequence_lengths,
combiners=combiners,
mode_override=mode_override,
num_features=num_features,
name=name)
@ -453,6 +460,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
table_ids,
device_ordinal,
max_sequence_lengths=None,
num_features=None,
combiners=None,
mode_override=None,
name=None):
@ -485,6 +493,11 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
be a non-sequence feature. If greater than 0, the corresponding feature is
a sequence feature with the given maximal length. If None, then we assume
a list of all zeroes.
num_features: A list of integers, the length of which must equal the
length of sample_splits. If non-empty, entries in this list must be at
least 1. For each batch element, we will take num_features rows of the
input tensor for embedding lookup. E.g., when sample_splits is empty,
the embedding indices must be of shape (batch_size*num_features).
combiners: A list of string scalars, one for each embedding table that
specify how to normalize the embedding activations after weighted
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
@ -512,6 +525,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
max_sequence_lengths=max_sequence_lengths,
combiners=combiners,
mode_override=mode_override,
num_features=num_features,
name=name)

View File

@ -1346,7 +1346,7 @@ tf_module {
}
member_method {
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
}
member_method {
name: "EnqueueTPUEmbeddingSparseBatch"
@ -1354,7 +1354,7 @@ tf_module {
}
member_method {
name: "EnqueueTPUEmbeddingSparseTensorBatch"
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
}
member_method {
name: "EnsureShape"

View File

@ -1346,7 +1346,7 @@ tf_module {
}
member_method {
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
}
member_method {
name: "EnqueueTPUEmbeddingSparseBatch"
@ -1354,7 +1354,7 @@ tf_module {
}
member_method {
name: "EnqueueTPUEmbeddingSparseTensorBatch"
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
}
member_method {
name: "EnsureShape"