diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc index 79ebc09adc2..4eaab1e6c72 100644 --- a/tensorflow/core/ops/tpu_embedding_ops.cc +++ b/tensorflow/core/ops/tpu_embedding_ops.cc @@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch") .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("EnqueueTPUEmbeddingSparseBatch") - .Input("sample_indices: N * int32") - .Input("embedding_indices: N * int32") - .Input("aggregation_weights: N * float32") + .Input("sample_indices: N * T1") + .Input("embedding_indices: N * T2") + .Input("aggregation_weights: N * T3") .Input("mode_override: string") + .Attr("T1: {int32,int64} = DT_INT32") + .Attr("T2: {int32,int64} = DT_INT32") + .Attr("T3: {float32,float64} = DT_FLOAT") .Attr("N: int >= 1") .Attr("device_ordinal: int = -1") .Attr("combiners: list(string) = []") @@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch") }); REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch") - .Input("sample_indices: N * int32") - .Input("embedding_indices: N * int32") - .Input("aggregation_weights: N * float32") + .Input("sample_indices: N * T1") + .Input("embedding_indices: N * T2") + .Input("aggregation_weights: N * T3") .Input("mode_override: string") + .Attr("T1: {int32,int64} = DT_INT32") + .Attr("T2: {int32,int64} = DT_INT32") + .Attr("T3: {float32,float64} = DT_FLOAT") .Attr("N: int >= 1") .Attr("device_ordinal: int = -1") .Attr("combiners: list(string) = []") diff --git a/tensorflow/python/tpu/ops/tpu_ops.py b/tensorflow/python/tpu/ops/tpu_ops.py index 38dd2734ac2..678d504740c 100644 --- a/tensorflow/python/tpu/ops/tpu_ops.py +++ b/tensorflow/python/tpu/ops/tpu_ops.py @@ -328,11 +328,14 @@ def enqueue_tpu_embedding_sparse_batch(sample_indices, and feature to which the corresponding embedding_indices and aggregation_weights values belong. 
sample_indices[i] must equal b * nf + f, where nf is the number of features from the corresponding table, f is - in [0, nf), and b is in [0, batch size). + in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed + and will be converted to int32 internally. embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. + tables. Both int32 and int64 are allowed and will be converted to int32 + internally. aggregation_weights: A list of rank 1 Tensors containing per sample -- - i.e. per (training example, feature) -- aggregation weights. + i.e. per (training example, feature) -- aggregation weights. Both float32 + and float64 are allowed and will be converted to float32 internally. device_ordinal: The TPU device to use. Should be >= 0 and less than the number of TPU cores in the task on which the node is placed. combiners: A list of string scalars, one for each embedding table that @@ -382,12 +385,15 @@ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, sample_indices: A list of rank 1 Tensors specifying the training example to which the corresponding embedding_indices and aggregation_weights values belong. It corresponds to sp_ids.indices[:,0] in - embedding_lookup_sparse(). + embedding_lookup_sparse(). Both int32 and int64 are allowed and will be + converted to int32 internally. embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). + tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both + int32 and int64 are allowed and will be converted to int32 internally. aggregation_weights: A list of rank 1 Tensors containing per training example aggregation weights. It corresponds to sp_weights.values in - embedding_lookup_sparse(). + embedding_lookup_sparse(). Both float32 and float64 are allowed and will + be converted to float32 internally. 
table_ids: A list of integers specifying the identifier of the embedding table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the corresponding input. The ith input is looked up using