Add pass that propagates TPU devices from operands to uses.
This pass finds values placed on TPU and propagates them to certain users. Supported ops include certain tf_executor dialect ops and trivial tf dialect ops (Identity, IdentityN, Shape). PiperOrigin-RevId: 338778252 Change-Id: Idbea53ea5a52df036cdd249f8d894e052a81a6a5
This commit is contained in:
parent
27fefb01cf
commit
d39bcbaefb
@ -884,6 +884,7 @@ cc_library(
|
||||
"transforms/tpu_cluster_cleanup_attributes.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
"transforms/tpu_colocate_composite_resource_ops.cc",
|
||||
"transforms/tpu_device_propagation.cc",
|
||||
"transforms/tpu_dynamic_layout_pass.cc",
|
||||
"transforms/tpu_dynamic_padding_mapper.cc",
|
||||
"transforms/tpu_extract_head_tail_outside_compilation.cc",
|
||||
|
@ -0,0 +1,383 @@
|
||||
// RUN: tf-opt %s -tf-tpu-device-propagation | FileCheck %s

// Tests function passthrough values.

// CHECK-LABEL: func @testArgToRet
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testArgToRet(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
  %0 = tf_executor.graph {
    tf_executor.fetch %arg0 : tensor<i64>
  }
  return %0 : tensor<i64>
}
|
||||
|
||||
// Tests supported ops.

// CHECK-LABEL: func @testIdentityOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testIdentityOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
  %0 = tf_executor.graph {
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#0 : tensor<i64>
  }
  return %0 : tensor<i64>
}

// CHECK-LABEL: func @testIdentityNOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testIdentityNOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64>, tensor<i32>) {
  %0:2 = tf_executor.graph {
    // CHECK: tf.IdentityN
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i32>) -> (tensor<i64>, tensor<i32>)
    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
  }
  return %0#0, %0#1 : tensor<i64>, tensor<i32>
}

// CHECK-LABEL: func @testShapeOp
// CHECK-SAME: ({{%.+}}: tensor<*xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<?xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testShapeOp(%arg0: tensor<*xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<?xi64> {
  %0 = tf_executor.graph {
    // CHECK: tf.Shape
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Shape"(%arg0) : (tensor<*xi64>) -> tensor<?xi64>
    tf_executor.fetch %1#0 : tensor<?xi64>
  }
  return %0 : tensor<?xi64>
}

// CHECK-LABEL: func @testEnterOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testEnterOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
  %0 = tf_executor.graph {
    // CHECK: tf_executor.Enter
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.Enter %arg0 frame "frame" : tensor<i64>
    tf_executor.fetch %1#0 : tensor<i64>
  }
  return %0 : tensor<i64>
}

// CHECK-LABEL: func @testExitOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testExitOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
  %0 = tf_executor.graph {
    // CHECK: tf_executor.Exit
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.Exit %arg0 : tensor<i64>
    tf_executor.fetch %1#0 : tensor<i64>
  }
  return %0 : tensor<i64>
}

// CHECK-LABEL: func @testMergeOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testMergeOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64>, tensor<i32>) {
  %0:2 = tf_executor.graph {
    // CHECK: tf_executor.Merge
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:3 = tf_executor.Merge %arg0, %arg1 : tensor<i64>
    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
  }
  return %0#0, %0#1 : tensor<i64>, tensor<i32>
}

// CHECK-LABEL: func @testSwitchOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i1> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testSwitchOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i1> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
  tf_executor.graph {
    // CHECK: tf_executor.Switch
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %0:3 = tf_executor.Switch %arg0, %arg1 : tensor<i64> {T = "tfdtype$DT_INT64"}
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %2:2 = tf_executor.island wraps "tf.Identity"(%0#1) : (tensor<i64>) -> tensor<i64>
    %3 = tf_executor.ControlTrigger %1#1, %2#1
    tf_executor.fetch %3 : !tf_executor.control
  }
  return
}
|
||||
|
||||
// Tests unsupported op does not have TPU device propagated.

// CHECK-LABEL: func @testUnsupportedOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> tensor<i64>
func @testUnsupportedOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
  %0 = tf_executor.graph {
    // CHECK: tf.UnsupportedOp
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.UnsupportedOp"(%arg0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#0 : tensor<i64>
  }
  return %0 : tensor<i64>
}

// Tests empty devices are overwritten.

// CHECK-LABEL: func @testEmptyDeviceOverwritten
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testEmptyDeviceOverwritten(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64> {tf.device = ""}) {
  %0 = tf_executor.graph {
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Identity"(%arg0) {device = ""} : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#0 : tensor<i64>
  }
  return %0 : tensor<i64>
}

// Tests only devices are propagated when all operands are on the same TPU
// device.

// CHECK-LABEL: func @testOperandsNoDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i32>)
// CHECK-SAME: -> (tensor<i64>, tensor<i32>)
func @testOperandsNoDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i32>) -> (tensor<i64>, tensor<i32>) {
  %0:2 = tf_executor.graph {
    // CHECK: tf.IdentityN
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i32>) -> (tensor<i64>, tensor<i32>)
    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
  }
  return %0#0, %0#1 : tensor<i64>, tensor<i32>
}

// CHECK-LABEL: func @testOperandsDifferentDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"})
// CHECK-SAME: -> (tensor<i64>, tensor<i32>)
func @testOperandsDifferentDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}) -> (tensor<i64>, tensor<i32>) {
  %0:2 = tf_executor.graph {
    // CHECK: tf.IdentityN
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:1"
    %1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i32>) -> (tensor<i64>, tensor<i32>)
    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
  }
  return %0#0, %0#1 : tensor<i64>, tensor<i32>
}

// Tests op with operand on different device does not have its device
// overwritten.

// CHECK-LABEL: func @testDifferentOperandAndOpDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testDifferentOperandAndOpDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
  tf_executor.graph {
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
    %0:2 = tf_executor.island wraps "tf.Identity"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %0#1 : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @testDifferentOperandAndResultDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"})
func @testDifferentOperandAndResultDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}) {
  %0 = tf_executor.graph {
    tf_executor.fetch %arg0 : tensor<i64>
  }
  return %0 : tensor<i64>
}
|
||||
|
||||
// Tests non TPU devices are not propagated.

// CHECK-LABEL: func @testNonTPUDevice
func @testNonTPUDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) {
  tf_executor.graph {
    // CHECK: tf.Identity
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
    %0:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %0#1 : !tf_executor.control
  }
  return
}

// Tests control dependencies are ignored for propagating devices.

// CHECK-LABEL: func @testControlDependenciesIgnored
func @testControlDependenciesIgnored(%arg0: tensor<i64>) {
  tf_executor.graph {
    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island(%0#1) wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#1 : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @testControlDependenciesMismatchedDevices
func @testControlDependenciesMismatchedDevices(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
  tf_executor.graph {
    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:1", value = dense<0> : tensor<i64>} : () -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island(%0#1) wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#1 : !tf_executor.control
  }
  return
}

// Tests LoopCond -> Switch where LoopCond has a different device is ignored.

// CHECK-LABEL: func @testLoopCondSwitchLinkDifferentDevice
func @testLoopCondSwitchLinkDifferentDevice() {
  tf_executor.graph {
    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<false> : tensor<i1>} : () -> tensor<i1>
    %1:2 = tf_executor.LoopCond %0#0 : (tensor<i1>) -> (tensor<i1>, !tf_executor.control) {}
    %2:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
    // CHECK: tf_executor.Switch
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %3:3 = tf_executor.Switch %2#0, %1#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %4:2 = tf_executor.island wraps "tf.Identity"(%3#0) : (tensor<i64>) -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %5:2 = tf_executor.island wraps "tf.Identity"(%3#1) : (tensor<i64>) -> tensor<i64>
    %6 = tf_executor.ControlTrigger %4#1, %5#1
    tf_executor.fetch %6 : !tf_executor.control
  }
  return
}

// Tests tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink has a
// device when an intermediate op in its loop has a device.

// CHECK-LABEL: func @testNextIterationNoDevice
func @testNextIterationNoDevice() {
  tf_executor.graph {
    // CHECK: tf_executor.NextIteration.Source
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {T = "tfdtype$DT_INT64"}
    // CHECK: tf.Identity
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    // CHECK: tf.IdentityN
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %2:2 = tf_executor.island wraps "tf.IdentityN"(%1#0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : (tensor<i64>) -> tensor<i64>
    // CHECK: tf_executor.NextIteration.Sink
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    tf_executor.NextIteration.Sink [%0#1] %2#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
    tf_executor.fetch %0#2 : !tf_executor.control
  }
  return
}

// Tests tf_executor.NextIteration with mismatched devices does not propagate
// either device.

// CHECK-LABEL: func @testNextIterationMismatchedDevices
func @testNextIterationMismatchedDevices() {
  tf_executor.graph {
    // CHECK: tf_executor.NextIteration.Source
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:1", T = "tfdtype$DT_INT64"}
    // CHECK: "tf.Identity"({{.+}}) :
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    // CHECK: tf_executor.NextIteration.Sink
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:0", T = "tfdtype$DT_INT64"}
    tf_executor.fetch %0#2 : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @testNextIterationMissingSourceDevice
func @testNextIterationMissingSourceDevice() {
  tf_executor.graph {
    // CHECK: tf_executor.NextIteration.Source
    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {T = "tfdtype$DT_INT64"}
    // CHECK: "tf.Identity"({{.+}}) :
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    // CHECK: tf_executor.NextIteration.Sink
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:0", T = "tfdtype$DT_INT64"}
    tf_executor.fetch %0#2 : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @testNextIterationMissingSinkDevice
func @testNextIterationMissingSinkDevice() {
  tf_executor.graph {
    // CHECK: tf_executor.NextIteration.Source
    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:1", T = "tfdtype$DT_INT64"}
    // CHECK: "tf.Identity"({{.+}}) :
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    // CHECK: tf_executor.NextIteration.Sink
    tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
    tf_executor.fetch %0#2 : !tf_executor.control
  }
  return
}
|
||||
|
||||
// Tests unsupported functions are not modified.

// CHECK-LABEL: func @testMultipleBlockFunc
func @testMultipleBlockFunc() {
  tf_executor.graph {
    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#1 : !tf_executor.control
  }
  br ^bb1
^bb1:
  return
}

// CHECK-LABEL: func @testMultipleGraphs
func @testMultipleGraphs() {
  tf_executor.graph {
    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#1 : !tf_executor.control
  }
  tf_executor.graph {
    tf_executor.fetch
  }
  return
}

// CHECK-LABEL: func @testNoGraph
func @testNoGraph() -> tensor<i64> {
  %0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
  // CHECK: tf.Identity
  // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
  %1 = "tf.Identity"(%0) : (tensor<i64>) -> tensor<i64>
  return %1 : tensor<i64>
}

// CHECK-LABEL: func @testMismatchedGraphResults
func @testMismatchedGraphResults() {
  %0 = tf_executor.graph {
    %1:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
    // CHECK: tf.Identity
    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
    %2:2 = tf_executor.island wraps "tf.Identity"(%1#0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %2#0 : tensor<i64>
  }
  return
}
|
@ -358,6 +358,9 @@ CreateTPUUpdateEmbeddingEnqueueOpInputsPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass();

// Creates a pass that propagates TPU devices to users.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass();

// Populates the supplied passmanager with the passes required to run the
// bridge.
void CreateTPUBridgePipeline(OpPassManager& pm);
|
||||
|
@ -0,0 +1,253 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kFuncDeviceAttr[] = "tf.device";
|
||||
|
||||
// Checks if a function only contains a tf_executor.graph whose results feed
// the function terminator unchanged. Only such functions are handled by this
// pass, so graph result devices map 1:1 onto function result attributes.
bool IsSupportedGraph(FuncOp func) {
  // Multi-block functions are not supported.
  if (!llvm::hasSingleElement(func)) return false;

  // Besides the terminator, the block must hold exactly one op...
  Block& block = func.front();
  if (!llvm::hasSingleElement(block.without_terminator())) return false;

  // ...and that op must be a tf_executor.graph.
  auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
  if (!graph) return false;

  // Each graph result must be returned, in order, by the terminator.
  Operation* terminator = block.getTerminator();
  if (graph.getNumResults() != terminator->getNumOperands()) return false;
  for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
    if (std::get<0>(result) != std::get<1>(result)) return false;

  return true;
}
|
||||
|
||||
// Checks if an operation of the tf_executor dialect can have TPU devices
// propagated through.
bool IsSupportedExecutorOp(Operation& op) {
  // Two ops "agree" on a device if both lack the attribute or both carry the
  // same device string.
  auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
    auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
    auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
    return (!lhs_device_attr && !rhs_device_attr) ||
           (lhs_device_attr && rhs_device_attr &&
            lhs_device_attr.getValue() == rhs_device_attr.getValue());
  };

  // Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
  // pair has matching devices or no devices.
  if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
    return ops_have_same_device(source, source.GetSink());
  } else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
    return ops_have_same_device(sink.GetSource(), sink);
  }

  // Remaining supported tf_executor ops (v1 control flow plus islands).
  return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
                   tf_executor::IslandOp, tf_executor::MergeOp,
                   tf_executor::SwitchOp>(op);
}
|
||||
|
||||
// Assigns all data results to a specified device.
|
||||
void PopulateDeviceForOpResults(
|
||||
Operation& op, llvm::StringRef device,
|
||||
llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
|
||||
Operation* op_to_update = &op;
|
||||
// Use tf_executor.island op if present as non v1 control flow op results are
|
||||
// forwarded by a parent tf_executor.island op.
|
||||
if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
|
||||
op_to_update = op_to_update->getParentOp();
|
||||
|
||||
for (Value result : op_to_update->getResults()) {
|
||||
if (result.getType().isa<tf_executor::TokenType>()) continue;
|
||||
if (result.getType().isa<tf_executor::ControlType>()) break;
|
||||
|
||||
value_to_device.insert({result, device});
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if an operation can have TPU devices propagated through.
|
||||
bool IsSupportedOpToSetDevice(Operation& op) {
|
||||
return IsSupportedExecutorOp(op) ||
|
||||
isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
|
||||
}
|
||||
|
||||
// Finds a nonconflicting TPU device for an operation from its operands. If an
// operand has no device or a non TPU device, or if there are conflicting
// devices, an empty StringRef will be returned. Control dependencies,
// NextIteration.Source -> NextIteration.Sink token dependencies, and
// LoopCond -> Switch data dependencies are ignored.
llvm::StringRef FindDeviceFromOperands(
    Operation& op,
    const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  llvm::StringRef new_device;
  const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
  for (Value operand : op.getOperands()) {
    // Token operands carry no data; skip them.
    if (operand.getType().isa<tf_executor::TokenType>()) continue;
    // Control operands trail all data operands; stop at the first one.
    if (operand.getType().isa<tf_executor::ControlType>()) break;

    // Ignore a Switch predicate fed by a LoopCond — the loop condition may
    // legitimately live on a different device.
    if (is_switch &&
        llvm::isa_and_nonnull<tf_executor::LoopCondOp>(operand.getDefiningOp()))
      continue;

    // Operand has no recorded TPU device: nothing to propagate.
    auto it = value_to_device.find(operand);
    if (it == value_to_device.end()) return llvm::StringRef();

    if (new_device.empty()) {
      new_device = it->getSecond();
      continue;
    }

    // Conflicting devices across operands: do not propagate.
    if (new_device != it->getSecond()) return llvm::StringRef();
  }

  return new_device;
}
|
||||
|
||||
// Propagates devices from function arguments: seeds `value_to_device` with
// every argument carrying a nonempty TPU `tf.device` attribute.
void PropagateDevicesFromArguments(
    FuncOp func, llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  for (BlockArgument& arg : func.getArguments()) {
    auto arg_device_attr =
        func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
    // Only nonempty TPU devices are eligible for propagation.
    if (!arg_device_attr || arg_device_attr.getValue().empty() ||
        !tensorflow::IsTPUDevice(arg_device_attr.getValue()))
      continue;
    value_to_device.insert({arg, arg_device_attr.getValue()});
  }
}
|
||||
|
||||
// Propagates devices from operation operands to results. Updating the device of
// a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will result
// in multiple passes over the tf_executor.graph to propagate devices in loops.
void PropagateDevicesInGraph(
    tf_executor::GraphOp graph,
    llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  auto ops = graph.GetBody().without_terminator();

  // Sweep the graph until a fixed point: assigning a device to a
  // NextIteration pair may enable further propagation to ops that appear
  // earlier in the block (loop back edges).
  bool updated_next_iteration = false;
  do {
    updated_next_iteration = false;
    for (Operation& op : ops) {
      if (!IsSupportedExecutorOp(op)) continue;

      Operation* op_to_update = &op;
      // Unpack inner op of tf_executor.island.
      if (auto island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
        if (!island_op.WrapsSingleOp()) continue;
        op_to_update = &island_op.GetBody().front();
      }

      // If op already has a TPU device set, simply propagate its device.
      auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
      const bool has_device = device_attr && !device_attr.getValue().empty();
      if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
        PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
                                   value_to_device);
        continue;
      }

      // Op has an unsupported device.
      if (has_device) continue;

      if (!IsSupportedOpToSetDevice(*op_to_update)) continue;

      // Derive a device from the op's operands; skip on conflict or no info.
      llvm::StringRef new_device =
          FindDeviceFromOperands(*op_to_update, value_to_device);
      if (new_device.empty()) continue;

      auto new_device_attr =
          mlir::StringAttr::get(new_device, op_to_update->getContext());
      op_to_update->setAttr(kDeviceAttr, new_device_attr);
      PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
                                 value_to_device);

      // Mirror the device onto the paired NextIteration.Source so values
      // flowing around the loop back edge pick it up on the next sweep.
      if (auto sink =
              llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
        auto source = sink.GetSource();
        source.setAttr(kDeviceAttr, new_device_attr);
        PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
                                   value_to_device);
        updated_next_iteration = true;
      }
    }
  } while (updated_next_iteration);
}
|
||||
|
||||
// Propagates devices to function results (`tf.device` result attributes),
// based on the devices recorded for the values fetched by the graph.
void PropagateDevicesToResults(
    FuncOp func, tf_executor::FetchOp fetch,
    const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
  for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
    // Control operands trail all data operands; stop at the first one.
    if (operand.get().getType().isa<tf_executor::ControlType>()) break;
    auto it = value_to_device.find(operand.get());
    if (it != value_to_device.end()) {
      // Do not overwrite an existing nonempty result device.
      auto device_attr = func.getResultAttrOfType<StringAttr>(
          operand.getOperandNumber(), kFuncDeviceAttr);
      if (device_attr && !device_attr.getValue().empty()) continue;
      func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
                         StringAttr::get(it->getSecond(), func.getContext()));
    }
  }
}
|
||||
|
||||
// Function pass that propagates TPU devices from producers to users (certain
// tf_executor dialect ops and trivial tf dialect ops such as Identity,
// IdentityN, and Shape).
struct TPUDevicePropagation
    : public PassWrapper<TPUDevicePropagation, FunctionPass> {
  void runOnFunction() override;
};

void TPUDevicePropagation::runOnFunction() {
  FuncOp func = getFunction();
  // Only single-block functions holding a lone tf_executor.graph whose
  // results are returned directly are handled.
  if (!IsSupportedGraph(func)) return;

  // Maps each value to the TPU device its producer is assigned to.
  llvm::DenseMap<Value, llvm::StringRef> value_to_device;
  PropagateDevicesFromArguments(func, value_to_device);
  auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
  PropagateDevicesInGraph(graph, value_to_device);
  PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Factory exposed through tf_passes.h for pipeline construction.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
  return std::make_unique<TPUDevicePropagation>();
}

// Registers the pass for use with tf-opt under -tf-tpu-device-propagation.
static PassRegistration<TPUDevicePropagation> pass(
    "tf-tpu-device-propagation", "Propagates TPU devices from ops to users");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
Loading…
x
Reference in New Issue
Block a user