Add legalization for TPUEmbeddingActivationsOp.
PiperOrigin-RevId: 334205860 Change-Id: I5511a301283993a593ed1ffc75e19e84c85cddb5
This commit is contained in:
parent
c8ab7af5df
commit
44a568d101
@ -11832,6 +11832,30 @@ For internal use only.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TPUEmbeddingActivationsOp : TF_Op<"TPUEmbeddingActivations", [NoSideEffect]> {
|
||||
let summary = "An op enabling differentiation of TPU Embeddings.";
|
||||
|
||||
let description = [{
|
||||
This op simply returns its first input, which is assumed to have been sliced
|
||||
from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of
|
||||
this op, and its first argument being a trainable Variable, enables automatic
|
||||
differentiation of graphs containing embeddings via the TPU Embedding Python
|
||||
libraries.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Float32Tensor:$embedding_variable,
|
||||
TF_Float32Tensor:$sliced_activations,
|
||||
|
||||
Confined<I64Attr, [IntMinValue<0>]>:$table_id,
|
||||
Confined<I64Attr, [IntMinValue<0>]>:$lookup_id
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Float32Tensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TF_TPUExecuteOp : TF_Op<"TPUExecute", []> {
|
||||
let summary = "Op that loads and executes a TPU program on a TPU device.";
|
||||
|
||||
|
@ -227,6 +227,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::StatelessTruncatedNormalOp>(),
|
||||
TypeID::get<TF::SubOp>(),
|
||||
TypeID::get<TF::TanOp>(),
|
||||
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
||||
TypeID::get<TF::TransposeOp>(),
|
||||
TypeID::get<TF::TruncateDivOp>(),
|
||||
TypeID::get<TF::TruncatedNormalOp>(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user