diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 41ace43a483..4615064dc5c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -2062,4 +2062,28 @@ An op that groups a list of partitioned inputs together. This op TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } +def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> { + let summary = [{ +An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned + }]; + + let description = [{ +outputs outside the XLA computation. + }]; + + let arguments = (ins + TF_Tensor:$inputs, + + DefaultValuedAttr:$partition_dim, + OptionalAttr:$_XlaSharding + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>; +} + #endif // TF_OPS