Add EnqueueTPUEmbeddingSparseTensorBatch and TF_EnqueueTPUEmbeddingRaggedTensorBatch ops to TF MLIR dialect.
PiperOrigin-RevId: 313673623 Change-Id: I5a0c803e55a036a40d0ef2a9469895cddec15932
This commit is contained in:
parent
5c77174291
commit
3e842e5ffc
@ -2670,6 +2670,76 @@ This operation creates a tensor of `shape` and `dtype`.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize]> {
|
||||
let summary = "Eases the porting of code that uses tf.nn.embedding_lookup().";
|
||||
|
||||
let description = [{
|
||||
sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
|
||||
to the ith feature. table_ids[i] indicates which embedding table to look up ith
|
||||
feature.
|
||||
|
||||
The tensors at corresponding positions in two of the input lists,
|
||||
embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
|
||||
with dim_size() equal to the total number of lookups into the table described by
|
||||
the corresponding feature.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_I32OrI64Tensor>:$sample_splits,
|
||||
Variadic<TF_I32OrI64Tensor>:$embedding_indices,
|
||||
Variadic<TF_F32OrF64Tensor>:$aggregation_weights,
|
||||
TF_StrTensor:$mode_override,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$combiners,
|
||||
I64ArrayAttr:$table_ids,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$max_sequence_lengths
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize]> {
|
||||
let summary = [{
|
||||
Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond
|
||||
to the ith feature. table_ids[i] indicates which embedding table to look up ith
|
||||
feature.
|
||||
|
||||
The tensors at corresponding positions in the three input lists (sample_indices,
|
||||
embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1
|
||||
with dim_size() equal to the total number of lookups into the table described by
|
||||
the corresponding feature.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_I32OrI64Tensor>:$sample_indices,
|
||||
Variadic<TF_I32OrI64Tensor>:$embedding_indices,
|
||||
Variadic<TF_F32OrF64Tensor>:$aggregation_weights,
|
||||
TF_StrTensor:$mode_override,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "-1">:$device_ordinal,
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$combiners,
|
||||
I64ArrayAttr:$table_ids,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$max_sequence_lengths
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
|
||||
let summary = "Returns the truth value of (x == y) element-wise.";
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user