Add an option for TPUEmbedding input tensors to have a first dimension which is a multiple of the batch_size.

PiperOrigin-RevId: 354344093
Change-Id: I0fbb6820b9fa0bb0128ed4a27e2f527937741cbf
Authored by Bruce Fontaine on 2021-01-28 10:43:09 -08:00; committed by TensorFlower Gardener
parent 8a356e8ca5
commit e994fb8f32
4 changed files with 20 additions and 4 deletions
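
In concrete terms, the new num_features option lets a dense (non-sparse) feature supply num_features[i] embedding-lookup rows per batch element, so the first dimension of its embedding_indices input becomes batch_size * num_features[i] rather than exactly batch_size. A minimal shape sketch with hypothetical values, assuming (as the new docstring text suggests, though it does not spell out the layout) that the rows for each batch element are stored contiguously:

import numpy as np

batch_size = 4
num_features = 2  # hypothetical: two lookup rows per batch element

# Dense case (empty sample_indices): the enqueue op now accepts an
# embedding-indices tensor whose first dimension is a multiple of batch_size.
embedding_indices = np.arange(batch_size * num_features, dtype=np.int32)
assert embedding_indices.shape == (batch_size * num_features,)  # i.e. (8,)

# Assumed layout: the rows for batch element b are the contiguous slice
# [b * num_features, (b + 1) * num_features).
rows_for_element_1 = embedding_indices[1 * num_features:2 * num_features]
print(rows_for_element_1)  # -> [2 3]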

View File

@@ -164,6 +164,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
     .Attr("combiners: list(string) = []")
     .Attr("table_ids: list(int)")
     .Attr("max_sequence_lengths: list(int) = []")
+    .Attr("num_features: list(int) = []")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);
@@ -180,6 +181,7 @@ REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
     .Attr("combiners: list(string) = []")
     .Attr("table_ids: list(int)")
     .Attr("max_sequence_lengths: list(int) = []")
+    .Attr("num_features: list(int) = []")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);

View File

@@ -380,6 +380,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                               table_ids,
                                               device_ordinal,
                                               max_sequence_lengths=None,
+                                              num_features=None,
                                               combiners=None,
                                               mode_override=None,
                                               name=None):
@@ -412,6 +413,11 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
       be a non-sequence feature, If greater than 0, the corresponding feature is
       a sequence feature with the given maximal length. If None, then we assume
       a list of all zeroes.
+    num_features: A list of integers, the size of which is equal to
+      sample_indices. If non-empty, entries in this list must be at least 1.
+      For each batch element, we will take num_features rows of the input
+      tensor for embedding lookup. E.g., when sample_indices is empty,
+      the embedding indices must be of shape (batch_size*num_features).
     combiners: A list of string scalars, one for each embedding table that
       specify how to normalize the embedding activations after weighted
       summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
@@ -439,6 +445,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
       max_sequence_lengths=max_sequence_lengths,
       combiners=combiners,
       mode_override=mode_override,
+      num_features=num_features,
       name=name)
@@ -453,6 +460,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                               table_ids,
                                               device_ordinal,
                                               max_sequence_lengths=None,
+                                              num_features=None,
                                               combiners=None,
                                               mode_override=None,
                                               name=None):
@@ -485,6 +493,11 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
       be a non-sequence feature, If greater than 0, the corresponding feature is
       a sequence feature with the given maximal length. If None, then we assume
       a list of all zeroes.
+    num_features: A list of integers, the size of which must be equal to
+      sample_indices. If non-empty, entries in this list must be at least 1.
+      For each batch element, we will take num_features rows of the input
+      tensor for embedding lookup. E.g., when sample_indices is empty,
+      the embedding indices must be of shape (batch_size*num_features).
     combiners: A list of string scalars, one for each embedding table that
       specify how to normalize the embedding activations after weighted
       summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
@@ -512,6 +525,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
       max_sequence_lengths=max_sequence_lengths,
       combiners=combiners,
       mode_override=mode_override,
+      num_features=num_features,
       name=name)
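
For reference, a usage sketch of the updated Python wrapper, enqueue_tpu_embedding_sparse_tensor_batch, with the new num_features argument. The module path, the single-table feature configuration, and the mode string below are illustrative assumptions; the call only builds the enqueue op and needs a real TPU host with a matching embedding configuration to execute.

import tensorflow as tf
from tensorflow.python.tpu.ops import tpu_ops  # assumed module path

batch_size = 8
num_features = [2]  # feature 0 contributes two lookup rows per batch element

# Dense feature: sample_indices and aggregation_weights are empty, so the
# embedding indices must have first dimension batch_size * num_features[0].
indices = tf.zeros([batch_size * num_features[0]], dtype=tf.int32)

tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
    sample_indices=[tf.zeros([0], dtype=tf.int32)],
    embedding_indices=[indices],
    aggregation_weights=[tf.zeros([0], dtype=tf.float32)],
    table_ids=[0],
    device_ordinal=0,
    combiners=["sum"],
    num_features=num_features,  # new argument added by this change
    mode_override="unspecified")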

View File

@@ -1346,7 +1346,7 @@ tf_module {
   }
   member_method {
     name: "EnqueueTPUEmbeddingRaggedTensorBatch"
-    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
   }
   member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
@@ -1354,7 +1354,7 @@ tf_module {
   }
   member_method {
     name: "EnqueueTPUEmbeddingSparseTensorBatch"
-    argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+    argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
   }
   member_method {
     name: "EnsureShape"

View File

@@ -1346,7 +1346,7 @@ tf_module {
   }
   member_method {
     name: "EnqueueTPUEmbeddingRaggedTensorBatch"
-    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
   }
   member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
@@ -1354,7 +1354,7 @@ tf_module {
   }
   member_method {
     name: "EnqueueTPUEmbeddingSparseTensorBatch"
-    argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
+    argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
   }
   member_method {
     name: "EnsureShape"