Add option for input tensors for TPUEmbedding to have a first dimension which is a multiple of the batch_size.
PiperOrigin-RevId: 354344093 Change-Id: I0fbb6820b9fa0bb0128ed4a27e2f527937741cbf
This commit is contained in:
parent
8a356e8ca5
commit
e994fb8f32
@ -164,6 +164,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
|
|||||||
.Attr("combiners: list(string) = []")
|
.Attr("combiners: list(string) = []")
|
||||||
.Attr("table_ids: list(int)")
|
.Attr("table_ids: list(int)")
|
||||||
.Attr("max_sequence_lengths: list(int) = []")
|
.Attr("max_sequence_lengths: list(int) = []")
|
||||||
|
.Attr("num_features: list(int) = []")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::UnknownShape);
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
@ -180,6 +181,7 @@ REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
|
|||||||
.Attr("combiners: list(string) = []")
|
.Attr("combiners: list(string) = []")
|
||||||
.Attr("table_ids: list(int)")
|
.Attr("table_ids: list(int)")
|
||||||
.Attr("max_sequence_lengths: list(int) = []")
|
.Attr("max_sequence_lengths: list(int) = []")
|
||||||
|
.Attr("num_features: list(int) = []")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::UnknownShape);
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
|
@ -380,6 +380,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
|||||||
table_ids,
|
table_ids,
|
||||||
device_ordinal,
|
device_ordinal,
|
||||||
max_sequence_lengths=None,
|
max_sequence_lengths=None,
|
||||||
|
num_features=None,
|
||||||
combiners=None,
|
combiners=None,
|
||||||
mode_override=None,
|
mode_override=None,
|
||||||
name=None):
|
name=None):
|
||||||
@ -412,6 +413,11 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
|||||||
be a non-sequence feature, If greater than 0, the corresponding feature is
|
be a non-sequence feature, If greater than 0, the corresponding feature is
|
||||||
a sequence feature with the given maximal length. If None, then we assume
|
a sequence feature with the given maximal length. If None, then we assume
|
||||||
a list of all zeroes.
|
a list of all zeroes.
|
||||||
|
num_features: A list of integers, the size of which is equal to
|
||||||
|
sample_indices. If non-empty, entries in this list must be at least 1.
|
||||||
|
For each batch element, we will take num_features rows of the input
|
||||||
|
tensor for embedding lookup. E.g., when sample_indices is empty,
|
||||||
|
the embedding indices must be of shape (batch_size*num_features).
|
||||||
combiners: A list of string scalars, one for each embedding table that
|
combiners: A list of string scalars, one for each embedding table that
|
||||||
specify how to normalize the embedding activations after weighted
|
specify how to normalize the embedding activations after weighted
|
||||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||||
@ -439,6 +445,7 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
|||||||
max_sequence_lengths=max_sequence_lengths,
|
max_sequence_lengths=max_sequence_lengths,
|
||||||
combiners=combiners,
|
combiners=combiners,
|
||||||
mode_override=mode_override,
|
mode_override=mode_override,
|
||||||
|
num_features=num_features,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
@ -453,6 +460,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
|
|||||||
table_ids,
|
table_ids,
|
||||||
device_ordinal,
|
device_ordinal,
|
||||||
max_sequence_lengths=None,
|
max_sequence_lengths=None,
|
||||||
|
num_features=None,
|
||||||
combiners=None,
|
combiners=None,
|
||||||
mode_override=None,
|
mode_override=None,
|
||||||
name=None):
|
name=None):
|
||||||
@ -485,6 +493,11 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
|
|||||||
be a non-sequence feature, If greater than 0, the corresponding feature is
|
be a non-sequence feature, If greater than 0, the corresponding feature is
|
||||||
a sequence feature with the given maximal length. If None, then we assume
|
a sequence feature with the given maximal length. If None, then we assume
|
||||||
a list of all zeroes.
|
a list of all zeroes.
|
||||||
|
num_features: A list of integers, the size of which must be equal to
|
||||||
|
sample_indices. If non-empty, entries in this list must be at least 1.
|
||||||
|
For each batch element, we will take num_features rows of the input
|
||||||
|
tensor for embedding lookup. E.g., when sample_indices is empty,
|
||||||
|
the embedding indices must be of shape (batch_size*num_features).
|
||||||
combiners: A list of string scalars, one for each embedding table that
|
combiners: A list of string scalars, one for each embedding table that
|
||||||
specify how to normalize the embedding activations after weighted
|
specify how to normalize the embedding activations after weighted
|
||||||
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
|
||||||
@ -512,6 +525,7 @@ def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
|
|||||||
max_sequence_lengths=max_sequence_lengths,
|
max_sequence_lengths=max_sequence_lengths,
|
||||||
combiners=combiners,
|
combiners=combiners,
|
||||||
mode_override=mode_override,
|
mode_override=mode_override,
|
||||||
|
num_features=num_features,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1346,7 +1346,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
|
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||||
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnqueueTPUEmbeddingSparseBatch"
|
name: "EnqueueTPUEmbeddingSparseBatch"
|
||||||
@ -1354,7 +1354,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||||
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnsureShape"
|
name: "EnsureShape"
|
||||||
|
@ -1346,7 +1346,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
|
name: "EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||||
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnqueueTPUEmbeddingSparseBatch"
|
name: "EnqueueTPUEmbeddingSparseBatch"
|
||||||
@ -1354,7 +1354,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||||
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
|
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'num_features\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "EnsureShape"
|
name: "EnsureShape"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user