Replace const llvm::SmallVector<>& with llvm::ArrayRef and const std::string& with llvm::StringRef in TPUExtractOutsideCompilation. (NFC)
PiperOrigin-RevId: 316748196 Change-Id: Icdfcaa5a808ae69e5a6286d5bd7c6a988dbbe616
This commit is contained in:
parent
a71c78bcf9
commit
4db7ec5201
|
@ -73,9 +73,8 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Moves `cluster_ops` to associated `launch_op` body.
|
// Moves `cluster_ops` to associated `launch_op` body.
|
||||||
void MoveOutsideClusterOpsToLaunchOp(
|
void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
|
||||||
tf_device::LaunchOp launch_op,
|
llvm::ArrayRef<Operation*> cluster_ops) {
|
||||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
|
||||||
MLIRContext* context = launch_op.getContext();
|
MLIRContext* context = launch_op.getContext();
|
||||||
Operation* terminator = launch_op.GetBody().getTerminator();
|
Operation* terminator = launch_op.GetBody().getTerminator();
|
||||||
|
|
||||||
|
@ -123,7 +122,7 @@ void PropagateParallelExecuteReturnToReplicate(
|
||||||
|
|
||||||
// Extracts all externally provided operands of `cluster_ops`.
|
// Extracts all externally provided operands of `cluster_ops`.
|
||||||
llvm::SmallSetVector<Value, 4> GetExternalOperands(
|
llvm::SmallSetVector<Value, 4> GetExternalOperands(
|
||||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
llvm::ArrayRef<Operation*> cluster_ops) {
|
||||||
llvm::SmallSetVector<Value, 4> external_values;
|
llvm::SmallSetVector<Value, 4> external_values;
|
||||||
|
|
||||||
for (Operation* op : cluster_ops) {
|
for (Operation* op : cluster_ops) {
|
||||||
|
@ -143,7 +142,7 @@ llvm::SmallSetVector<Value, 4> GetExternalOperands(
|
||||||
|
|
||||||
// Extracts all externally used outputs of `cluster_ops`.
|
// Extracts all externally used outputs of `cluster_ops`.
|
||||||
llvm::SmallVector<Value, 4> GetExternalOutputs(
|
llvm::SmallVector<Value, 4> GetExternalOutputs(
|
||||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
llvm::ArrayRef<Operation*> cluster_ops) {
|
||||||
llvm::SmallSetVector<Value, 4> external_outputs;
|
llvm::SmallSetVector<Value, 4> external_outputs;
|
||||||
|
|
||||||
for (Operation* op : cluster_ops) {
|
for (Operation* op : cluster_ops) {
|
||||||
|
@ -166,7 +165,7 @@ llvm::SmallVector<Value, 4> GetExternalOutputs(
|
||||||
// as an operand. If there are no external_inputs, set insertion point to first
|
// as an operand. If there are no external_inputs, set insertion point to first
|
||||||
// cluster_op.
|
// cluster_op.
|
||||||
void SetHostComputeInsertion(
|
void SetHostComputeInsertion(
|
||||||
OpBuilder* builder, const llvm::SmallVector<Operation*, 8>& cluster_ops,
|
OpBuilder* builder, llvm::ArrayRef<Operation*> cluster_ops,
|
||||||
const llvm::SmallSetVector<Value, 4>& external_inputs) {
|
const llvm::SmallSetVector<Value, 4>& external_inputs) {
|
||||||
if (external_inputs.empty()) builder->setInsertionPoint(cluster_ops.front());
|
if (external_inputs.empty()) builder->setInsertionPoint(cluster_ops.front());
|
||||||
for (const auto& cluster_op : cluster_ops) {
|
for (const auto& cluster_op : cluster_ops) {
|
||||||
|
@ -183,9 +182,9 @@ void SetHostComputeInsertion(
|
||||||
// using `communication_key`.
|
// using `communication_key`.
|
||||||
TF::_HostComputeMlirOp CreateHostCompute(
|
TF::_HostComputeMlirOp CreateHostCompute(
|
||||||
OpBuilder* builder, tf_device::ClusterOp tpu_cluster,
|
OpBuilder* builder, tf_device::ClusterOp tpu_cluster,
|
||||||
const llvm::SmallVector<Operation*, 8>& cluster_ops,
|
llvm::ArrayRef<Operation*> cluster_ops,
|
||||||
const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
|
const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
|
||||||
const std::string& communication_key) {
|
llvm::StringRef communication_key) {
|
||||||
llvm::SmallVector<Type, 4> device_output_types;
|
llvm::SmallVector<Type, 4> device_output_types;
|
||||||
for (const auto& output : outputs)
|
for (const auto& output : outputs)
|
||||||
device_output_types.push_back(output.getType());
|
device_output_types.push_back(output.getType());
|
||||||
|
@ -201,10 +200,9 @@ TF::_HostComputeMlirOp CreateHostCompute(
|
||||||
|
|
||||||
void MoveOutsideCompiledOps(
|
void MoveOutsideCompiledOps(
|
||||||
tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name,
|
tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name,
|
||||||
tf_device::LaunchOp host_launch_op,
|
tf_device::LaunchOp host_launch_op, llvm::ArrayRef<Operation*> cluster_ops,
|
||||||
const llvm::SmallVector<Operation*, 8>& cluster_ops,
|
|
||||||
const llvm::SmallSetVector<Value, 4>& external_inputs,
|
const llvm::SmallSetVector<Value, 4>& external_inputs,
|
||||||
const llvm::SmallVector<Value, 4>& external_outputs) {
|
llvm::ArrayRef<Value> external_outputs) {
|
||||||
if (external_inputs.empty() && external_outputs.empty()) {
|
if (external_inputs.empty() && external_outputs.empty()) {
|
||||||
MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops);
|
MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops);
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Reference in New Issue