Support both int32/int64 and float32/float64 in TPU embedding enqueue ops.
PiperOrigin-RevId: 236421400
This commit is contained in:
parent
a9064fdddc
commit
d63fb26158
@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
||||
.Input("sample_indices: N * int32")
|
||||
.Input("embedding_indices: N * int32")
|
||||
.Input("aggregation_weights: N * float32")
|
||||
.Input("sample_indices: N * T1")
|
||||
.Input("embedding_indices: N * T2")
|
||||
.Input("aggregation_weights: N * T3")
|
||||
.Input("mode_override: string")
|
||||
.Attr("T1: {int32,int64} = DT_INT32")
|
||||
.Attr("T2: {int32,int64} = DT_INT32")
|
||||
.Attr("T3: {float32,float64} = DT_FLOAT")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.Attr("combiners: list(string) = []")
|
||||
@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
||||
});
|
||||
|
||||
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
|
||||
.Input("sample_indices: N * int32")
|
||||
.Input("embedding_indices: N * int32")
|
||||
.Input("aggregation_weights: N * float32")
|
||||
.Input("sample_indices: N * T1")
|
||||
.Input("embedding_indices: N * T2")
|
||||
.Input("aggregation_weights: N * T3")
|
||||
.Input("mode_override: string")
|
||||
.Attr("T1: {int32,int64} = DT_INT32")
|
||||
.Attr("T2: {int32,int64} = DT_INT32")
|
||||
.Attr("T3: {float32,float64} = DT_FLOAT")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.Attr("combiners: list(string) = []")
|
||||
|
@ -328,11 +328,14 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
|
||||
and feature to which the corresponding embedding_indices and
|
||||
aggregation_weights values belong. sample_indices[i] must equal b * nf +
|
||||
f, where nf is the number of features from the corresponding table, f is
|
||||
in [0, nf), and b is in [0, batch size).
|
||||
in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
|
||||
and will be converted to int32 internally.
|
||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||
tables.
|
||||
tables. Both int32 and int64 are allowed and will be converted to int32
|
||||
internally.
|
||||
aggregation_weights: A list of rank 1 Tensors containing per sample --
|
||||
i.e. per (training example, feature) -- aggregation weights.
|
||||
i.e. per (training example, feature) -- aggregation weights. Both float32
|
||||
and float64 are allowed and will be converted to float32 internally.
|
||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||
number of TPU cores in the task on which the node is placed.
|
||||
combiners: A list of string scalars, one for each embedding table that
|
||||
@ -382,12 +385,15 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
||||
sample_indices: A list of rank 1 Tensors specifying the training example
|
||||
to which the corresponding embedding_indices and aggregation_weights
|
||||
values belong. It corresponds to sp_ids.indices[:,0] in
|
||||
embedding_lookup_sparse().
|
||||
embedding_lookup_sparse(). Both int32 and int64 are allowed and will be
|
||||
converted to int32 internally.
|
||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||
tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
|
||||
tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
|
||||
int32 and int64 are allowed and will be converted to int32 internally.
|
||||
aggregation_weights: A list of rank 1 Tensors containing per training
|
||||
example aggregation weights. It corresponds to sp_weights.values in
|
||||
embedding_lookup_sparse().
|
||||
embedding_lookup_sparse(). Both float32 and float64 are allowed and will
|
||||
be converted to float32 internally.
|
||||
table_ids: A list of integers specifying the identifier of the embedding
|
||||
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
|
||||
lookup the corresponding input. The ith input is looked up using
|
||||
|
Loading…
Reference in New Issue
Block a user