Register TPUPartitionedOutput in TF MLIR ODS.
PiperOrigin-RevId: 348503783 Change-Id: Ib918eb5657e2eb530082355bc46d08d0c7f0e131
This commit is contained in:
parent
6ee022991d
commit
74cda1cdde
@ -2062,4 +2062,28 @@ An op that groups a list of partitioned inputs together. This op
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
outputs outside the XLA computation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$inputs,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "0">:$partition_dim,
|
||||
OptionalAttr<StrAttr>:$_XlaSharding
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
Loading…
x
Reference in New Issue
Block a user