Support both int32/int64 and float32/float64 in TPU embedding enqueue ops.

PiperOrigin-RevId: 236421400
This commit is contained in:
A. Unique TensorFlower 2019-03-01 20:21:36 -08:00 committed by TensorFlower Gardener
parent a9064fdddc
commit d63fb26158
2 changed files with 24 additions and 12 deletions

View File

@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
.SetShapeFn(shape_inference::UnknownShape); .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch") REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
.Input("sample_indices: N * int32") .Input("sample_indices: N * T1")
.Input("embedding_indices: N * int32") .Input("embedding_indices: N * T2")
.Input("aggregation_weights: N * float32") .Input("aggregation_weights: N * T3")
.Input("mode_override: string") .Input("mode_override: string")
.Attr("T1: {int32,int64} = DT_INT32")
.Attr("T2: {int32,int64} = DT_INT32")
.Attr("T3: {float32,float64} = DT_FLOAT")
.Attr("N: int >= 1") .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1") .Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []") .Attr("combiners: list(string) = []")
@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
}); });
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch") REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
.Input("sample_indices: N * int32") .Input("sample_indices: N * T1")
.Input("embedding_indices: N * int32") .Input("embedding_indices: N * T2")
.Input("aggregation_weights: N * float32") .Input("aggregation_weights: N * T3")
.Input("mode_override: string") .Input("mode_override: string")
.Attr("T1: {int32,int64} = DT_INT32")
.Attr("T2: {int32,int64} = DT_INT32")
.Attr("T3: {float32,float64} = DT_FLOAT")
.Attr("N: int >= 1") .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1") .Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []") .Attr("combiners: list(string) = []")

View File

@ -328,11 +328,14 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices,
and feature to which the corresponding embedding_indices and and feature to which the corresponding embedding_indices and
aggregation_weights values belong. sample_indices[i] must equal b * nf + aggregation_weights values belong. sample_indices[i] must equal b * nf +
f, where nf is the number of features from the corresponding table, f is f, where nf is the number of features from the corresponding table, f is
in [0, nf), and b is in [0, batch size). in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
and will be converted to int32 internally.
embedding_indices: A list of rank 1 Tensors, indices into the embedding embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. tables. Both int32 and int64 are allowed and will be converted to int32
internally.
aggregation_weights: A list of rank 1 Tensors containing per sample -- aggregation_weights: A list of rank 1 Tensors containing per sample --
i.e. per (training example, feature) -- aggregation weights. i.e. per (training example, feature) -- aggregation weights. Both float32
and float64 are allowed and will be converted to float32 internally.
device_ordinal: The TPU device to use. Should be >= 0 and less than the device_ordinal: The TPU device to use. Should be >= 0 and less than the
number of TPU cores in the task on which the node is placed. number of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that combiners: A list of string scalars, one for each embedding table that
@ -382,12 +385,15 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
sample_indices: A list of rank 1 Tensors specifying the training example sample_indices: A list of rank 1 Tensors specifying the training example
to which the corresponding embedding_indices and aggregation_weights to which the corresponding embedding_indices and aggregation_weights
values belong. It corresponds to sp_ids.indices[:,0] in values belong. It corresponds to sp_ids.indices[:,0] in
embedding_lookup_sparse(). embedding_lookup_sparse(). Both int32 and int64 are allowed and will be
converted to int32 internally.
embedding_indices: A list of rank 1 Tensors, indices into the embedding embedding_indices: A list of rank 1 Tensors, indices into the embedding
tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
int32 and int64 are allowed and will be converted to int32 internally.
aggregation_weights: A list of rank 1 Tensors containing per training aggregation_weights: A list of rank 1 Tensors containing per training
example aggregation weights. It corresponds to sp_weights.values in example aggregation weights. It corresponds to sp_weights.values in
embedding_lookup_sparse(). embedding_lookup_sparse(). Both float32 and float64 are allowed and will
be converted to float32 internally.
table_ids: A list of integers specifying the identifier of the embedding table_ids: A list of integers specifying the identifier of the embedding
table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
lookup the corresponding input. The ith input is looked up using lookup the corresponding input. The ith input is looked up using