Update TPUExtractHeadTailOutsideCompilation to use side effect analysis to determine if ops can be head or tail extracted.
There were cases where side effecting ops with no operands were head extracted when they should have been tail extracted (e.g. tf.WriteSummary followed by tf.FlushSummaryWriter). Other cases may include lifting ops to before or after the cluster when there are other side effecting ops before or after such ops that cannot be lifted.

PiperOrigin-RevId: 329020980
Change-Id: Ib31d6fe9373d985fa645901fa1f91b284662e54c
parent 2fd6b56a0c
commit 1255008e06
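For context, here is a minimal sketch of the situation the description above refers to; the IR is illustrative and not part of the change, the op signatures are simplified, and %writer/%step stand in for values defined elsewhere:

"tf_device.cluster"() ( {
  %result = "tf.A"() : () -> tensor<f32>
  // tf.WriteSummary uses a value produced inside the cluster, so it can only
  // be moved to after the cluster (tail extraction).
  "tf.WriteSummary"(%writer, %step, %result) {_xla_outside_compilation = "cluster1"}
      : (tensor<*x!tf.resource>, tensor<i64>, tensor<f32>) -> ()
  // tf.FlushSummaryWriter has no operands defined inside the cluster, so the
  // old operand-only check could lift it to before the cluster, making the
  // flush run ahead of the write it is supposed to flush.
  "tf.FlushSummaryWriter"(%writer) {_xla_outside_compilation = "cluster1"}
      : (tensor<*x!tf.resource>) -> ()
  tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [],
    topology = "", device_assignment = []} : () -> ()

The diff below threads TF::SideEffectAnalysis through the head/tail extraction helpers so that an outside compiled op is only lifted if the side effecting op ordered immediately before it (for the head) or after it (for the tail) can be lifted with it.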
@@ -173,7 +173,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
   func @tail_single_outside_compiled_op() {
     // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
-    // CHECK-NEXT: "tf.C"
+    // CHECK-NEXT: "tf.NoOp"
     // CHECK-NEXT: tf_device.return %[[A_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG: num_cores_per_replica = 1
@@ -190,7 +190,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     "tf_device.cluster"() ( {
       %a = "tf.A"() : () -> tensor<i32>
       "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
-      "tf.C"() : () -> ()
+      "tf.NoOp"() : () -> ()
       tf_device.return
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
     return
@@ -200,7 +200,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
   func @tail_single_outside_compiled_op_user() -> tensor<i32> {
     // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
-    // CHECK-NEXT: "tf.C"
+    // CHECK-NEXT: "tf.NoOp"
     // CHECK-NEXT: tf_device.return %[[A_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG: num_cores_per_replica = 1
@@ -217,7 +217,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     %cluster = "tf_device.cluster"() ( {
       %a = "tf.A"() : () -> tensor<i32>
       %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
-      "tf.C"() : () -> ()
+      "tf.NoOp"() : () -> ()
       tf_device.return %b : tensor<i32>
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
     // CHECK: return %[[LAUNCH_OUT]]
@@ -262,7 +262,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     %b = "tf.B"() : () -> tensor<i32>
     // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
     // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
-    // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E"
+    // CHECK-NEXT: %[[E_OUT:.*]] = "tf.Const"
     // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG: num_cores_per_replica = 1
@@ -279,7 +279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     %cluster:5 = "tf_device.cluster"() ( {
       %c = "tf.C"() : () -> tensor<i32>
       %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
-      %e = "tf.E"() : () -> tensor<i32>
+      %e = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
       tf_device.return %a, %b, %c, %d, %e : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
     // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1
@@ -320,14 +320,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
   func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor<i32>) {
     // CHECK-NOT: "tf_device.launch"
     // CHECK: "tf_device.cluster"
-    // CHECK-NEXT: "tf.A"
+    // CHECK-NEXT: "tf.Identity"
     // CHECK-NEXT: "tf.B"
-    // CHECK-NEXT: "tf.C"
+    // CHECK-NEXT: "tf.Identity"
     // CHECK-NEXT: tf_device.return
     "tf_device.cluster"() ( {
-      %a = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
+      %a = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
       %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
-      "tf.C"(%b) : (tensor<i32>) -> ()
+      %c = "tf.Identity"(%b) : (tensor<i32>) -> tensor<i32>
       tf_device.return
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
     return
@@ -379,7 +379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
     // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]])
-    // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
+    // CHECK-NEXT: "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
     // CHECK-NEXT: tf_device.return %[[C_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG: num_cores_per_replica = 1
@@ -399,11 +399,72 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
         %b = "tf.B"() : () -> tensor<i32>
         %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
         %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
-        %e = "tf.E"(%c, %a) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        %e:2 = "tf.IdentityN"(%c, %a) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
         tf_device.return
       }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
       tf_device.return
     }
     return
   }
+
+  // CHECK-LABEL: func @side_effect_middle
+  func @side_effect_middle() {
+    // CHECK: "tf_device.cluster"
+    // CHECK-NEXT: "tf.A"
+    // CHECK-NEXT: "tf.B"
+    // CHECK-NEXT: "tf.C"
+    // CHECK-NEXT: tf_device.return
+    "tf_device.cluster"() ( {
+      "tf.A"() : () -> ()
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      "tf.C"() : () -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
+
+  // CHECK-LABEL: func @side_effect_head_no_operand
+  func @side_effect_head_no_operand() {
+    // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK-NEXT: "tf.B"
+    // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
+    // CHECK-NEXT: tf_device.return %[[C_OUT]]
+    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
+
+    // CHECK: "tf_device.cluster"
+    // CHECK-NEXT: "tf.Const"
+    // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]])
+    // CHECK-NEXT: tf_device.return
+
+    "tf_device.cluster"() ( {
+      %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      %c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor<i32>
+      "tf.D"(%c) : (tensor<i32>) -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
+
+  // CHECK-LABEL: func @side_effect_tail_no_operand
+  func @side_effect_tail_no_operand() {
+    // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
+    // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
+    // CHECK-NEXT: "tf.Const"
+    // CHECK-NEXT: tf_device.return %[[A_OUT]]
+
+    // CHECK: "tf_device.launch"()
+    // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
+    // CHECK-NEXT: "tf.C"
+    // CHECK-NEXT: tf_device.return
+    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    "tf_device.cluster"() ( {
+      %a = "tf.A"() : () -> tensor<i32>
+      "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
+      "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
 }

@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/Attributes.h" // from @llvm-project
 #include "mlir/IR/Block.h" // from @llvm-project
 #include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
 #include "mlir/IR/Operation.h" // from @llvm-project
 #include "mlir/IR/Value.h" // from @llvm-project
 #include "mlir/IR/Visitors.h" // from @llvm-project
@@ -34,6 +35,7 @@ limitations under the License.
 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
 #include "mlir/Support/LogicalResult.h" // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -118,7 +120,10 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
 // computation or other ops that can be extracted, and have no operands from
 // other ops in the TPU computation that cannot be extracted.
 llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
+    const TF::SideEffectAnalysis& side_effect_analysis,
     tf_device::ClusterOp cluster) {
+  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
+      cluster.getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.body();
   llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;

@@ -127,6 +132,15 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
     // An outside compiled op can be extracted if its operands are not from
     // other ops in the cluster that cannot be extracted.
+
+    // Check if the side effecting op right before this side effecting op, if
+    // it is side effecting, can be head extracted. Because of op ordering due
+    // to side effects, if this is not true, this op cannot be head extracted.
+    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
+    if (!predecessors.empty() &&
+        !head_outside_compiled_ops.contains(predecessors.back()))
+      continue;
+
     auto walk_result = cluster_op.walk([&](Operation* op) {
       for (Value operand : op->getOperands()) {
         Operation* operand_op = GetOpOfValue(operand);
@@ -168,11 +182,11 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
 // Extracts and move outside compiled ops that have no dependencies in the
 // cluster to before the cluster.
 mlir::LogicalResult LiftHeadOutsideCompiledOps(
-    OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
-    tf_device::ClusterOp cluster, std::string* host_device,
-    bool* cluster_updated) {
+    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
+    const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
+    std::string* host_device, bool* cluster_updated) {
   llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
-      FindOutsideCompiledOpsAtHead(cluster);
+      FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
   if (head_outside_compiled_ops.empty()) return success();
   if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
                                                          host_device)))
@@ -191,9 +205,12 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps(
 // TPU computation or other ops that can be extracted, and have no results used
 // by other ops in the TPU computation that cannot be extracted.
 void FindOutsideCompiledOpsAtTailAndClusterResults(
+    const TF::SideEffectAnalysis& side_effect_analysis,
     tf_device::ClusterOp cluster,
     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
     llvm::SmallVectorImpl<Value>* cluster_results) {
+  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
+      cluster.getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.body();
   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
   Operation* terminator = cluster.GetBody().getTerminator();
||||
@ -205,6 +222,15 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
|
||||
for (Operation& cluster_op : cluster_ops) {
|
||||
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||
|
||||
// Check if the side effecting op right after this side effecting op, if
|
||||
// it is side effecting, can be tail extracted. Because of op ordering due
|
||||
// to side effects, if this is not true, this op cannot be tail extracted.
|
||||
auto successors = analysis.DirectControlSuccessors(
|
||||
&cluster_op, [&terminator](Operation* op) { return op != terminator; });
|
||||
if (!successors.empty() &&
|
||||
!tail_outside_compiled_ops_set.contains(successors.front()))
|
||||
continue;
|
||||
|
||||
llvm::SmallVector<int, 4> results_to_forward;
|
||||
bool can_be_extracted =
|
||||
llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
|
||||
@@ -293,13 +319,14 @@ tf_device::ClusterOp UpdateClusterResults(
 // Extracts and move outside compiled ops that do not create dependencies in the
 // cluster to after the cluster.
 mlir::LogicalResult LiftTailOutsideCompiledOps(
-    OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
-    std::string host_device, tf_device::ClusterOp* cluster,
-    bool* cluster_updated) {
+    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
+    const mlir::TF::RuntimeDevices& devices, std::string host_device,
+    tf_device::ClusterOp* cluster, bool* cluster_updated) {
   llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
   llvm::SmallVector<Value, 4> cluster_results;
-  FindOutsideCompiledOpsAtTailAndClusterResults(
-      *cluster, &tail_outside_compiled_ops, &cluster_results);
+  FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
+                                                &tail_outside_compiled_ops,
+                                                &cluster_results);
   if (tail_outside_compiled_ops.empty()) return success();

   if (host_device.empty())
@@ -365,6 +392,7 @@ struct TPUExtractHeadTailOutsideCompilation
 };

 void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
+  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
   // Get runtime devices information from the closest parent module.
   auto module = getOperation();
   mlir::TF::RuntimeDevices devices;
@@ -379,10 +407,12 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
   for (tf_device::ClusterOp cluster : clusters) {
     std::string host_device;
     bool cluster_updated = false;
-    if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster,
-                                          &host_device, &cluster_updated)) ||
-        failed(LiftTailOutsideCompiledOps(&builder, devices, host_device,
-                                          &cluster, &cluster_updated)))
+    if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
+                                          devices, cluster, &host_device,
+                                          &cluster_updated)) ||
+        failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
+                                          devices, host_device, &cluster,
+                                          &cluster_updated)))
       return signalPassFailure();
     if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
   }
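To close the loop on the example above: with side effect analysis in place, and following the launch/device pattern the new tests check for, the intended outcome for the summary-op case is tail extraction of both ops onto the host in their original order. This is an illustrative sketch rather than actual pass output, with the same simplified signatures and placeholder values as in the earlier sketch:

%output = "tf_device.cluster"() ( {
  %result = "tf.A"() : () -> tensor<f32>
  tf_device.return %result : tensor<f32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [],
    topology = "", device_assignment = []} : () -> tensor<f32>
"tf_device.launch"() ( {
  // The write still precedes the flush; neither op was moved ahead of the
  // cluster that produces %output.
  "tf.WriteSummary"(%writer, %step, %output)
      : (tensor<*x!tf.resource>, tensor<i64>, tensor<f32>) -> ()
  "tf.FlushSummaryWriter"(%writer) : (tensor<*x!tf.resource>) -> ()
  tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()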