diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 64ea0732e8c..aa1601c4032 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -10329,6 +10329,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> { + let summary = [{ +A pseudo-op to represent host-side computation in an XLA program. + }]; + + let description = [{ + }]; + + let arguments = (ins + Variadic<TF_Tensor>:$inputs, + + StrArrayAttr:$ancestors, + TF_ShapeAttrArray:$shapes, + SymbolRefAttr:$shape_inference_graph, + StrAttr:$key, + DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns, + DefaultValuedAttr<I64Attr, "0">:$tpu_core + ); + + let results = (outs + Variadic<TF_Tensor>:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> { let summary = "Wraps the XLA Sort operator, documented at"; @@ -10377,6 +10404,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { + let summary = "An op to receive a tensor from the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_ShapeAttr:$shape, + StrAttr:$key + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>; +} + def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> { let summary = "Wraps the XLA Reduce operator, documented at"; @@ -10441,6 +10486,23 @@ i=0...N-1. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { + let summary = "An op to send a tensor to the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$key + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -10582,3 +10644,44 @@ used to look up the program in the compilation cache. TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>; TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; } + +def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> { + let summary = [{ +A placeholder op to receive values from a running XLA computation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs + Variadic<TF_Tensor>:$outputs + ); + + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + +def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { + let summary = "A placeholder op to send values to a running XLA computation."; + + let description = [{ + }]; + + let arguments = (ins + Variadic<TF_Tensor>:$inputs, + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; +} diff --git a/tensorflow/core/ops/tpu_host_compute_ops.cc b/tensorflow/core/ops/tpu_host_compute_ops.cc index 48aeb81ac13..753cc0015d9 100644 --- a/tensorflow/core/ops/tpu_host_compute_ops.cc +++ b/tensorflow/core/ops/tpu_host_compute_ops.cc @@ -28,8 +28,7 @@ REGISTER_OP("_XlaSendFromHost") .SetIsStateful() .SetShapeFn(::tensorflow::shape_inference::NoOutputs) .Doc(R"doc( -A placeholder op for multiple values that will be sent from TensorFlow to a -running XLA computation. +A placeholder op to send values to a running XLA computation. inputs: A list of tensors that will be sent to the XLA computation. dynamic_key: The key sent at runtime by the compile node to identify which @@ -49,8 +48,7 @@ REGISTER_OP("_XlaRecvAtHost") .SetIsStateful() .SetShapeFn(::tensorflow::shape_inference::UnknownShape) .Doc(R"doc( -A placeholder op for multiple values that will be sent to TensorFlow from a -running XLA computation. +A placeholder op to receive values from a running XLA computation. dynamic_key: The key sent at runtime by the compile node to identify which execution the transfer corresponds to.