Support both int32/int64 and float32/float64 in TPU embedding enqueue ops.

PiperOrigin-RevId: 236421400
Author: A. Unique TensorFlower  2019-03-01 20:21:36 -08:00
Committed by: TensorFlower Gardener
parent a9064fdddc
commit d63fb26158
2 changed files with 24 additions and 12 deletions


@@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
     .SetShapeFn(shape_inference::UnknownShape);
 
 REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
-    .Input("sample_indices: N * int32")
-    .Input("embedding_indices: N * int32")
-    .Input("aggregation_weights: N * float32")
+    .Input("sample_indices: N * T1")
+    .Input("embedding_indices: N * T2")
+    .Input("aggregation_weights: N * T3")
     .Input("mode_override: string")
+    .Attr("T1: {int32,int64} = DT_INT32")
+    .Attr("T2: {int32,int64} = DT_INT32")
+    .Attr("T3: {float32,float64} = DT_FLOAT")
     .Attr("N: int >= 1")
     .Attr("device_ordinal: int = -1")
     .Attr("combiners: list(string) = []")
@@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
     });
 
 REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
-    .Input("sample_indices: N * int32")
-    .Input("embedding_indices: N * int32")
-    .Input("aggregation_weights: N * float32")
+    .Input("sample_indices: N * T1")
+    .Input("embedding_indices: N * T2")
+    .Input("aggregation_weights: N * T3")
     .Input("mode_override: string")
+    .Attr("T1: {int32,int64} = DT_INT32")
+    .Attr("T2: {int32,int64} = DT_INT32")
+    .Attr("T3: {float32,float64} = DT_FLOAT")
     .Attr("N: int >= 1")
     .Attr("device_ordinal: int = -1")
     .Attr("combiners: list(string) = []")
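For illustration, the new T1/T2/T3 attrs mean callers can feed int64 indices and float64 weights directly instead of pre-casting them. Below is a minimal sketch of such a call, assuming the op is reachable through the tf.raw_ops endpoint in the TensorFlow build being used; the argument names come from the registration above, and the mode_override value is only a placeholder.

# Hypothetical sketch, not part of this commit: exercises the widened dtypes.
# In practice this op must be placed on a TPU host; it will not run on CPU.
import tensorflow as tf

# int64 indices and float64 weights are now accepted (previously int32/float32 only).
sample_indices = [tf.constant([0, 0, 1], dtype=tf.int64)]
embedding_indices = [tf.constant([5, 9, 3], dtype=tf.int64)]
aggregation_weights = [tf.constant([1.0, 1.0, 1.0], dtype=tf.float64)]

tf.raw_ops.EnqueueTPUEmbeddingSparseBatch(
    sample_indices=sample_indices,
    embedding_indices=embedding_indices,
    aggregation_weights=aggregation_weights,
    mode_override="unspecified",  # placeholder; valid modes depend on the runtime
    device_ordinal=0)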


@@ -328,11 +328,14 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
       and feature to which the corresponding embedding_indices and
       aggregation_weights values belong. sample_indices[i] must equal b * nf +
       f, where nf is the number of features from the corresponding table, f is
-      in [0, nf), and b is in [0, batch size).
+      in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
+      and will be converted to int32 internally.
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
-      tables.
+      tables. Both int32 and int64 are allowed and will be converted to int32
+      internally.
     aggregation_weights: A list of rank 1 Tensors containing per sample --
-      i.e. per (training example, feature) -- aggregation weights.
+      i.e. per (training example, feature) -- aggregation weights. Both float32
+      and float64 are allowed and will be converted to float32 internally.
     device_ordinal: The TPU device to use. Should be >= 0 and less than the
       number of TPU cores in the task on which the node is placed.
     combiners: A list of string scalars, one for each embedding table that
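The "converted to int32/float32 internally" wording above means the op now performs the narrowing on the caller's behalf. As a rough illustration (a hypothetical helper, not the commit's actual kernel-side handling), the equivalent casts a caller previously had to apply themselves look like this:

import tensorflow as tf

def narrow_enqueue_dtypes(sample_indices, embedding_indices, aggregation_weights):
  """Hypothetical illustration: narrow int64 -> int32 and float64 -> float32."""
  def cast_ints(tensors):
    return [tf.cast(t, tf.int32) if t.dtype == tf.int64 else t for t in tensors]
  def cast_floats(tensors):
    return [tf.cast(t, tf.float32) if t.dtype == tf.float64 else t for t in tensors]
  return (cast_ints(sample_indices),
          cast_ints(embedding_indices),
          cast_floats(aggregation_weights))

With this change, a pre-cast step of this kind is no longer needed on the caller's side.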
@@ -382,12 +385,15 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
     sample_indices: A list of rank 1 Tensors specifying the training example
       to which the corresponding embedding_indices and aggregation_weights
       values belong. It corresponds to sp_ids.indices[:,0] in
-      embedding_lookup_sparse().
+      embedding_lookup_sparse(). Both int32 and int64 are allowed and will be
+      converted to int32 internally.
     embedding_indices: A list of rank 1 Tensors, indices into the embedding
-      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
+      tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
+      int32 and int64 are allowed and will be converted to int32 internally.
     aggregation_weights: A list of rank 1 Tensors containing per training
       example aggregation weights. It corresponds to sp_weights.values in
-      embedding_lookup_sparse().
+      embedding_lookup_sparse(). Both float32 and float64 are allowed and will
+      be converted to float32 internally.
     table_ids: A list of integers specifying the identifier of the embedding
       table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
       lookup the corresponding input. The ith input is looked up using
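To make the sp_ids/sp_weights correspondence in the docstring concrete, here is a hypothetical sketch of preparing the three input lists for enqueue_tpu_embedding_sparse_tensor_batch from a SparseTensor. Since SparseTensor indices are always int64, the widened dtypes let these values be passed through without casting; the enqueue call itself is omitted.

import tensorflow as tf

# Hypothetical example data; not taken from this commit.
sp_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [2, 0]],                # SparseTensor indices are int64
    values=tf.constant([7, 3, 1], dtype=tf.int64),   # embedding ids
    dense_shape=[4, 2])
sp_weights = tf.SparseTensor(
    indices=sp_ids.indices,
    values=tf.constant([1.0, 0.5, 2.0], dtype=tf.float64),
    dense_shape=sp_ids.dense_shape)

# Mirrors the docstring: sample_indices <- sp_ids.indices[:, 0],
# embedding_indices <- sp_ids.values, aggregation_weights <- sp_weights.values.
sample_indices = [sp_ids.indices[:, 0]]     # int64, now accepted directly
embedding_indices = [sp_ids.values]         # int64, now accepted directly
aggregation_weights = [sp_weights.values]   # float64, now accepted directly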