Add option for input tensors for TPUEmbedding to have a first dimension which is a multiple of the batch_size.
PiperOrigin-RevId: 354344093 Change-Id: I0fbb6820b9fa0bb0128ed4a27e2f527937741cbf
This commit is contained in:
parent
8a356e8ca5
commit
e994fb8f32
@ -164,6 +164,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
|
||||
.Attr("combiners: list(string) = []")
|
||||
.Attr("table_ids: list(int)")
|
||||
.Attr("max_sequence_lengths: list(int) = []")
|
||||
.Attr("num_features: list(int) = []")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
@ -180,6 +181,7 @@ REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
|
||||
.Attr("combiners: list(string) = []")
|
||||
.Attr("table_ids: list(int)")
|
||||
.Attr("max_sequence_lengths: list(int) = []")
|
||||
.Attr("num_features: list(int) = []")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
|
@ -380,6 +380,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
||||
table_ids,
|
||||
device_ordinal,
|
||||
max_sequence_lengths=None,
|
||||
num_features=None,
|
||||
combiners=None,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
@ -412,6 +413,11 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
||||
be a non-sequence feature, If greater than 0, the corresponding feature is
|
||||
a sequence feature with the given maximal length. If None, then we assume
|
||||
a list of all zeroes.
|
||||
num_features: A list of integers, the size of which is equal to
|
||||
sample_indices. If non-empty, entries in this list must be at least 1.
|
||||
For each batch element, we will take num_features rows of the input
|
||||
tensor for embedding lookup. E.g., when sample_indices is empty,
|
||||
the embedding indices must be of shape (batch_size*num_features).
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
specify how to normalize the embedding activations after weighted
|
||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||
@ -439,6 +445,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
combiners=combiners,
|
||||
mode_override=mode_override,
|
||||
num_features=num_features,
|
||||
name=name)
|
||||
|
||||
|
||||
@ -453,6 +460,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
|
||||
table_ids,
|
||||
device_ordinal,
|
||||
max_sequence_lengths=None,
|
||||
num_features=None,
|
||||
combiners=None,
|
||||
mode_override=None,
|
||||
name=None):
|
||||
@ -485,6 +493,11 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
|
||||
be a non-sequence feature, If greater than 0, the corresponding feature is
|
||||
a sequence feature with the given maximal length. If None, then we assume
|
||||
a list of all zeroes.
|
||||
num_features: A list of integers, the size of which must be equal to
|
||||
sample_indices. If non-empty, entries in this list must be at least 1.
|
||||
For each batch element, we will take num_features rows of the input
|
||||
tensor for embedding lookup. E.g., when sample_indices is empty,
|
||||
the embedding indices must be of shape (batch_size*num_features).
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
specify how to normalize the embedding activations after weighted
|
||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||
@ -512,6 +525,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
combiners=combiners,
|
||||
mode_override=mode_override,
|
||||
num_features=num_features,
|
||||
name=name)
|
||||
|
||||
|
||||
|
@ -1346,7 +1346,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
||||
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "EnqueueTPUEmbeddingSparseBatch"
|
||||
@ -1354,7 +1354,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
||||
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "EnsureShape"
|
||||
|
@ -1346,7 +1346,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
||||
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "EnqueueTPUEmbeddingSparseBatch"
|
||||
@ -1354,7 +1354,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
||||
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "EnsureShape"
|
||||
|
Loading…
x
Reference in New Issue
Block a user