Add EnqueueTPUEmbeddingSparseTensorBatch and TF_EnqueueTPUEmbeddingRaggedTensorBatch ops to TF MLIR dialect.

PiperOrigin-RevId: 313673623
Change-Id: I5a0c803e55a036a40d0ef2a9469895cddec15932
This commit is contained in:
A. Unique TensorFlower 2020-05-28 15:43:40 -07:00 committed by TensorFlower Gardener
parent 5c77174291
commit 3e842e5ffc

View File

@ -2670,6 +2670,76 @@ This operation creates a tensor of `shape` and `dtype`.
let hasFolder = 1;
}
def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize]> {
let summary = "Eases the porting of code that uses tf.nn.embedding_lookup().";
let description = [{
sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
to the ith feature. table_ids[i] indicates which embedding table to look up ith
feature.
The tensors at corresponding positions in two of the input lists,
embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
with dim_size() equal to the total number of lookups into the table described by
the corresponding feature.
}];
let arguments = (ins
Variadic<TF_I32OrI64Tensor>:$sample_splits,
Variadic<TF_I32OrI64Tensor>:$embedding_indices,
Variadic<TF_F32OrF64Tensor>:$aggregation_weights,
TF_StrTensor:$mode_override,
DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
DefaultValuedAttr<StrArrayAttr, "{}">:$combiners,
I64ArrayAttr:$table_ids,
DefaultValuedAttr<I64ArrayAttr, "{}">:$max_sequence_lengths
);
let results = (outs);
TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize]> {
let summary = [{
Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
}];
let description = [{
sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond
to the ith feature. table_ids[i] indicates which embedding table to look up ith
feature.
The tensors at corresponding positions in the three input lists (sample_indices,
embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1
with dim_size() equal to the total number of lookups into the table described by
the corresponding feature.
}];
let arguments = (ins
Variadic<TF_I32OrI64Tensor>:$sample_indices,
Variadic<TF_I32OrI64Tensor>:$embedding_indices,
Variadic<TF_F32OrF64Tensor>:$aggregation_weights,
TF_StrTensor:$mode_override,
DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
DefaultValuedAttr<StrArrayAttr, "{}">:$combiners,
I64ArrayAttr:$table_ids,
DefaultValuedAttr<I64ArrayAttr, "{}">:$max_sequence_lengths
);
let results = (outs);
TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
let summary = "Returns the truth value of (x == y) element-wise.";