Generate MLIR ops for TPU Host/Device communication for outside compilation.

These ops are needed for communicating dependencies (data or control flow) between TPU device computations and outside-compiled computations that run on the host.

PiperOrigin-RevId: 311580827
Change-Id: Ia82623ae2a3535b829691952063724cfaedf22bb
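For illustration only (not part of this change): in TF-dialect MLIR, a data dependency from the device computation to the host and back could be expressed with the new single-tensor transfer ops roughly as in the sketch below. The channel keys, tensor types, and the #tf.shape attribute spelling are assumptions for the sketch, not values taken from this commit.

    // Hypothetical sketch (device side): send one tensor to the host, then
    // receive one tensor back. Matching key strings pair each transfer with
    // its host-side counterpart; all names and types here are placeholders.
    "tf.XlaSendToHost"(%device_val) {key = "host_compute_channel_0_args"}
        : (tensor<2xf32>) -> ()
    %from_host = "tf.XlaRecvFromHost"()
        {key = "host_compute_channel_0_retvals", shape = #tf.shape<2>}
        : () -> tensor<2xf32>
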
@@ -10329,6 +10329,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> {
+  let summary = [{
+A pseudo-op to represent host-side computation in an XLA program.
+  }];
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor>:$inputs,
+
+    StrArrayAttr:$ancestors,
+    TF_ShapeAttrArray:$shapes,
+    SymbolRefAttr:$shape_inference_graph,
+    StrAttr:$key,
+    DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
+    DefaultValuedAttr<I64Attr, "0">:$tpu_core
+  );
+
+  let results = (outs
+    Variadic<TF_Tensor>:$outputs
+  );
+
+  TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
+}
+
 def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> {
   let summary = "Wraps the XLA Sort operator, documented at";
 
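For illustration only (again, not part of the diff): the XlaHostCompute pseudo-op defined above could appear in TF-dialect MLIR roughly as follows. The operand and result types, the key, the @host_shape_inference symbol, and the #tf.shape spelling are assumed placeholders.

    // Hypothetical sketch: a device-side stand-in for a host computation that
    // takes one tensor and returns one tensor. Attribute values are
    // placeholders; shape_inference_graph references a (hypothetical) graph
    // used to infer the host computation's output shapes.
    %host_out = "tf.XlaHostCompute"(%device_val) {
        ancestors = [],
        cost_estimate_ns = 1000000 : i64,
        key = "host_compute_channel_1",
        shape_inference_graph = @host_shape_inference,
        shapes = [#tf.shape<2x2>],
        tpu_core = 0 : i64
      } : (tensor<2x2xf32>) -> tensor<2x2xf32>
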
@@ -10377,6 +10404,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> {
+  let summary = "An op to receive a tensor from the host.";
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    TF_ShapeAttr:$shape,
+    StrAttr:$key
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>;
+}
+
 def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> {
   let summary = "Wraps the XLA Reduce operator, documented at";
 
@@ -10441,6 +10486,23 @@ i=0...N-1.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> {
+  let summary = "An op to send a tensor to the host.";
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input,
+
+    StrAttr:$key
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> {
   let summary = [{
 Computes the eigen decomposition of a batch of self-adjoint matrices
@@ -10582,3 +10644,44 @@ used to look up the program in the compilation cache.
   TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>;
   TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>;
 }
+
+def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> {
+  let summary = [{
+A placeholder op to receive values from a running XLA computation.
+  }];
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    TF_StrTensor:$dynamic_key,
+
+    StrAttr:$key,
+    I64Attr:$device_ordinal
+  );
+
+  let results = (outs
+    Variadic<TF_Tensor>:$outputs
+  );
+
+  TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
+}
+
+def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> {
+  let summary = "A placeholder op to send values to a running XLA computation.";
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor>:$inputs,
+    TF_StrTensor:$dynamic_key,
+
+    StrAttr:$key,
+    I64Attr:$device_ordinal
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
+}
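For illustration only: the host-side placeholder ops added at the end of the file could be used together roughly as below. The %dynamic_key operand is the string tensor produced at runtime by the compile node; the keys, tensor types, and the !tf.string spelling are assumptions for the sketch.

    // Hypothetical sketch (host side): receive the values sent by the device,
    // run some host computation (elided), and send the results back to the
    // running XLA computation. Names and types are placeholders.
    %args = "tf._XlaRecvAtHost"(%dynamic_key)
        {device_ordinal = 0 : i64, key = "host_compute_channel_0_args"}
        : (tensor<!tf.string>) -> tensor<2xf32>
    // ... host computation producing %retvals from %args ...
    "tf._XlaSendFromHost"(%retvals, %dynamic_key)
        {device_ordinal = 0 : i64, key = "host_compute_channel_0_retvals"}
        : (tensor<2xf32>, tensor<!tf.string>) -> ()
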
@@ -28,8 +28,7 @@ REGISTER_OP("_XlaSendFromHost")
     .SetIsStateful()
     .SetShapeFn(::tensorflow::shape_inference::NoOutputs)
     .Doc(R"doc(
-A placeholder op for multiple values that will be sent from TensorFlow to a
-running XLA computation.
+A placeholder op to send values to a running XLA computation.
 
 inputs: A list of tensors that will be sent to the XLA computation.
 dynamic_key: The key sent at runtime by the compile node to identify which
@@ -49,8 +48,7 @@ REGISTER_OP("_XlaRecvAtHost")
     .SetIsStateful()
     .SetShapeFn(::tensorflow::shape_inference::UnknownShape)
     .Doc(R"doc(
-A placeholder op for multiple values that will be sent to TensorFlow from a
-running XLA computation.
+A placeholder op to receive values from a running XLA computation.
 
 dynamic_key: The key sent at runtime by the compile node to identify which
 execution the transfer corresponds to.