Support both int32/int64 and float32/float64 in TPU embedding enqueue ops.
PiperOrigin-RevId: 236421400
This commit is contained in:
parent
a9064fdddc
commit
d63fb26158
@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
|
|||||||
.SetShapeFn(shape_inference::UnknownShape);
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
||||||
.Input("sample_indices: N * int32")
|
.Input("sample_indices: N * T1")
|
||||||
.Input("embedding_indices: N * int32")
|
.Input("embedding_indices: N * T2")
|
||||||
.Input("aggregation_weights: N * float32")
|
.Input("aggregation_weights: N * T3")
|
||||||
.Input("mode_override: string")
|
.Input("mode_override: string")
|
||||||
|
.Attr("T1: {int32,int64} = DT_INT32")
|
||||||
|
.Attr("T2: {int32,int64} = DT_INT32")
|
||||||
|
.Attr("T3: {float32,float64} = DT_FLOAT")
|
||||||
.Attr("N: int >= 1")
|
.Attr("N: int >= 1")
|
||||||
.Attr("device_ordinal: int = -1")
|
.Attr("device_ordinal: int = -1")
|
||||||
.Attr("combiners: list(string) = []")
|
.Attr("combiners: list(string) = []")
|
||||||
@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
|||||||
});
|
});
|
||||||
|
|
||||||
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
|
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
|
||||||
.Input("sample_indices: N * int32")
|
.Input("sample_indices: N * T1")
|
||||||
.Input("embedding_indices: N * int32")
|
.Input("embedding_indices: N * T2")
|
||||||
.Input("aggregation_weights: N * float32")
|
.Input("aggregation_weights: N * T3")
|
||||||
.Input("mode_override: string")
|
.Input("mode_override: string")
|
||||||
|
.Attr("T1: {int32,int64} = DT_INT32")
|
||||||
|
.Attr("T2: {int32,int64} = DT_INT32")
|
||||||
|
.Attr("T3: {float32,float64} = DT_FLOAT")
|
||||||
.Attr("N: int >= 1")
|
.Attr("N: int >= 1")
|
||||||
.Attr("device_ordinal: int = -1")
|
.Attr("device_ordinal: int = -1")
|
||||||
.Attr("combiners: list(string) = []")
|
.Attr("combiners: list(string) = []")
|
||||||
|
@ -328,11 +328,14 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
|
|||||||
and feature to which the corresponding embedding_indices and
|
and feature to which the corresponding embedding_indices and
|
||||||
aggregation_weights values belong. sample_indices[i] must equal b * nf +
|
aggregation_weights values belong. sample_indices[i] must equal b * nf +
|
||||||
f, where nf is the number of features from the corresponding table, f is
|
f, where nf is the number of features from the corresponding table, f is
|
||||||
in [0, nf), and b is in [0, batch size).
|
in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
|
||||||
|
and will be converted to int32 internally.
|
||||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||||
tables.
|
tables. Both int32 and int64 are allowed and will be converted to int32
|
||||||
|
internally.
|
||||||
aggregation_weights: A list of rank 1 Tensors containing per sample --
|
aggregation_weights: A list of rank 1 Tensors containing per sample --
|
||||||
i.e. per (training example, feature) -- aggregation weights.
|
i.e. per (training example, feature) -- aggregation weights. Both float32
|
||||||
|
and float64 are allowed and will be converted to float32 internally.
|
||||||
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
device_ordinal: The TPU device to use. Should be >= 0 and less than the
|
||||||
number of TPU cores in the task on which the node is placed.
|
number of TPU cores in the task on which the node is placed.
|
||||||
combiners: A list of string scalars, one for each embedding table that
|
combiners: A list of string scalars, one for each embedding table that
|
||||||
@ -382,12 +385,15 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
|
|||||||
sample_indices: A list of rank 1 Tensors specifying the training example
|
sample_indices: A list of rank 1 Tensors specifying the training example
|
||||||
to which the corresponding embedding_indices and aggregation_weights
|
to which the corresponding embedding_indices and aggregation_weights
|
||||||
values belong. It corresponds to sp_ids.indices[:,0] in
|
values belong. It corresponds to sp_ids.indices[:,0] in
|
||||||
embedding_lookup_sparse().
|
embedding_lookup_sparse(). Both int32 and int64 are allowed and will be
|
||||||
|
converted to int32 internally.
|
||||||
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
embedding_indices: A list of rank 1 Tensors, indices into the embedding
|
||||||
tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
|
tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
|
||||||
|
int32 and int64 are allowed and will be converted to int32 internally.
|
||||||
aggregation_weights: A list of rank 1 Tensors containing per training
|
aggregation_weights: A list of rank 1 Tensors containing per training
|
||||||
example aggregation weights. It corresponds to sp_weights.values in
|
example aggregation weights. It corresponds to sp_weights.values in
|
||||||
embedding_lookup_sparse().
|
embedding_lookup_sparse(). Both float32 and float64 are allowed and will
|
||||||
|
be converted to float32 internally.
|
||||||
table_ids: A list of integers specifying the identifier of the embedding
|
table_ids: A list of integers specifying the identifier of the embedding
|
||||||
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
|
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
|
||||||
lookup the corresponding input. The ith input is looked up using
|
lookup the corresponding input. The ith input is looked up using
|
||||||
|
Loading…
Reference in New Issue
Block a user