Add legalization for TPUEmbeddingActivationsOp.

PiperOrigin-RevId: 334205860
Change-Id: I5511a301283993a593ed1ffc75e19e84c85cddb5
This commit is contained in:
Pankaj Kanwar 2020-09-28 11:51:44 -07:00 committed by TensorFlower Gardener
parent c8ab7af5df
commit 44a568d101
2 changed files with 25 additions and 0 deletions

View File

@ -11832,6 +11832,30 @@ For internal use only.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TPUEmbeddingActivationsOp : TF_Op<"TPUEmbeddingActivations", [NoSideEffect]> {
let summary = "An op enabling differentiation of TPU Embeddings.";
let description = [{
This op simply returns its first input, which is assumed to have been sliced
from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of
this op, and its first argument being a trainable Variable, enables automatic
differentiation of graphs containing embeddings via the TPU Embedding Python
libraries.
}];
let arguments = (ins
TF_Float32Tensor:$embedding_variable,
TF_Float32Tensor:$sliced_activations,
Confined<I64Attr, [IntMinValue<0>]>:$table_id,
Confined<I64Attr, [IntMinValue<0>]>:$lookup_id
);
let results = (outs
TF_Float32Tensor:$output
);
}
def TF_TPUExecuteOp : TF_Op<"TPUExecute", []> {
let summary = "Op that loads and executes a TPU program on a TPU device.";

View File

@ -227,6 +227,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::StatelessTruncatedNormalOp>(),
TypeID::get<TF::SubOp>(),
TypeID::get<TF::TanOp>(),
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TruncateDivOp>(),
TypeID::get<TF::TruncatedNormalOp>(),