Add pass that propagates TPU devices from operands to uses.

This pass finds values placed on a TPU device and propagates that device to certain users. Supported users include certain tf_executor dialect ops and trivial tf dialect ops (Identity, IdentityN, Shape).
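As a minimal sketch (function name hypothetical, mirroring the tests added below): a tf.Identity whose operand is a function argument placed on a TPU device picks up that device, which is then forwarded to the function result.

// Before the pass:
func @example(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
  %0 = tf_executor.graph {
    %1:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
    tf_executor.fetch %1#0 : tensor<i64>
  }
  return %0 : tensor<i64>
}
// After the pass, tf.Identity carries
// device = "/job:localhost/replica:0/task:0/device:TPU:0" and the function
// result carries the matching tf.device attribute.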

PiperOrigin-RevId: 338778252
Change-Id: Idbea53ea5a52df036cdd249f8d894e052a81a6a5
Andy Ly 2020-10-23 18:06:58 -07:00 committed by TensorFlower Gardener
parent 27fefb01cf
commit d39bcbaefb
4 changed files with 640 additions and 0 deletions

tensorflow/compiler/mlir/tensorflow/BUILD

@@ -884,6 +884,7 @@ cc_library(
"transforms/tpu_cluster_cleanup_attributes.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_colocate_composite_resource_ops.cc",
"transforms/tpu_device_propagation.cc",
"transforms/tpu_dynamic_layout_pass.cc",
"transforms/tpu_dynamic_padding_mapper.cc",
"transforms/tpu_extract_head_tail_outside_compilation.cc",

tensorflow/compiler/mlir/tensorflow/tests/tpu_device_propagation.mlir

@@ -0,0 +1,383 @@
// RUN: tf-opt %s -tf-tpu-device-propagation | FileCheck %s
// Tests function passthrough values.
// CHECK-LABEL: func @testArgToRet
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testArgToRet(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
%0 = tf_executor.graph {
tf_executor.fetch %arg0 : tensor<i64>
}
return %0 : tensor<i64>
}
// Tests supported ops.
// CHECK-LABEL: func @testIdentityOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testIdentityOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
%0 = tf_executor.graph {
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#0 : tensor<i64>
}
return %0 : tensor<i64>
}
// CHECK-LABEL: func @testIdentityNOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testIdentityNOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64>, tensor<i32>) {
%0:2 = tf_executor.graph {
// CHECK: tf.IdentityN
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i32>) -> (tensor<i64>, tensor<i32>)
tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
}
return %0#0, %0#1 : tensor<i64>, tensor<i32>
}
// CHECK-LABEL: func @testShapeOp
// CHECK-SAME: ({{%.+}}: tensor<*xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<?xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testShapeOp(%arg0: tensor<*xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<?xi64> {
%0 = tf_executor.graph {
// CHECK: tf.Shape
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Shape"(%arg0) : (tensor<*xi64>) -> tensor<?xi64>
tf_executor.fetch %1#0 : tensor<?xi64>
}
return %0 : tensor<?xi64>
}
// CHECK-LABEL: func @testEnterOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testEnterOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
%0 = tf_executor.graph {
// CHECK: tf_executor.Enter
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.Enter %arg0 frame "frame" : tensor<i64>
tf_executor.fetch %1#0 : tensor<i64>
}
return %0 : tensor<i64>
}
// CHECK-LABEL: func @testExitOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testExitOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
%0 = tf_executor.graph {
// CHECK: tf_executor.Exit
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.Exit %arg0 : tensor<i64>
tf_executor.fetch %1#0 : tensor<i64>
}
return %0 : tensor<i64>
}
// CHECK-LABEL: func @testMergeOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testMergeOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64>, tensor<i32>) {
%0:2 = tf_executor.graph {
// CHECK: tf_executor.Merge
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:3 = tf_executor.Merge %arg0, %arg1 : tensor<i64>
tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
}
return %0#0, %0#1 : tensor<i64>, tensor<i32>
}
// CHECK-LABEL: func @testSwitchOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i1> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testSwitchOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i1> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
tf_executor.graph {
// CHECK: tf_executor.Switch
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%0:3 = tf_executor.Switch %arg0, %arg1 : tensor<i64> {T = "tfdtype$DT_INT64"}
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%2:2 = tf_executor.island wraps "tf.Identity"(%0#1) : (tensor<i64>) -> tensor<i64>
%3 = tf_executor.ControlTrigger %1#1, %2#1
tf_executor.fetch %3 : !tf_executor.control
}
return
}
// Tests that an unsupported op does not have a TPU device propagated to it.
// CHECK-LABEL: func @testUnsupportedOp
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> tensor<i64>
func @testUnsupportedOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
%0 = tf_executor.graph {
// CHECK: tf.UnsupportedOp
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.UnsupportedOp"(%arg0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#0 : tensor<i64>
}
return %0 : tensor<i64>
}
// Tests empty devices are overwritten.
// CHECK-LABEL: func @testEmptyDeviceOverwritten
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testEmptyDeviceOverwritten(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64> {tf.device = ""}) {
%0 = tf_executor.graph {
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Identity"(%arg0) {device = ""} : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#0 : tensor<i64>
}
return %0 : tensor<i64>
}
// Tests a device is only propagated when all operands are on the same TPU
// device.
// CHECK-LABEL: func @testOperandsNoDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i32>)
// CHECK-SAME: -> (tensor<i64>, tensor<i32>)
func @testOperandsNoDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i32>) -> (tensor<i64>, tensor<i32>) {
%0:2 = tf_executor.graph {
// CHECK: tf.IdentityN
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i32>) -> (tensor<i64>, tensor<i32>)
tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
}
return %0#0, %0#1 : tensor<i64>, tensor<i32>
}
// CHECK-LABEL: func @testOperandsDifferentDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"})
// CHECK-SAME: -> (tensor<i64>, tensor<i32>)
func @testOperandsDifferentDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}) -> (tensor<i64>, tensor<i32>) {
%0:2 = tf_executor.graph {
// CHECK: tf.IdentityN
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:1"
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i32>) -> (tensor<i64>, tensor<i32>)
tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i32>
}
return %0#0, %0#1 : tensor<i64>, tensor<i32>
}
// Tests an op with an operand on a different device does not have its device
// overwritten.
// CHECK-LABEL: func @testDifferentOperandAndOpDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testDifferentOperandAndOpDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
tf_executor.graph {
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
%0:2 = tf_executor.island wraps "tf.Identity"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %0#1 : !tf_executor.control
}
return
}
// CHECK-LABEL: func @testDifferentOperandAndResultDevice
// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"})
func @testDifferentOperandAndResultDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}) {
%0 = tf_executor.graph {
tf_executor.fetch %arg0 : tensor<i64>
}
return %0 : tensor<i64>
}
// Tests non-TPU devices are not propagated.
// CHECK-LABEL: func @testNonTPUDevice
func @testNonTPUDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) {
tf_executor.graph {
// CHECK: tf.Identity
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
%0:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %0#1 : !tf_executor.control
}
return
}
// Tests control dependencies are ignored for propagating devices.
// CHECK-LABEL: func @testControlDependenciesIgnored
func @testControlDependenciesIgnored(%arg0: tensor<i64>) {
tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf.Identity
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island(%0#1) wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#1 : !tf_executor.control
}
return
}
// CHECK-LABEL: func @testControlDependenciesMismatchedDevices
func @testControlDependenciesMismatchedDevices(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:1", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island(%0#1) wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#1 : !tf_executor.control
}
return
}
// Tests that a LoopCond -> Switch data dependency is ignored when the LoopCond
// is on a different device.
// CHECK-LABEL: func @testLoopCondSwitchLinkDifferentDevice
func @testLoopCondSwitchLinkDifferentDevice() {
tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<false> : tensor<i1>} : () -> tensor<i1>
%1:2 = tf_executor.LoopCond %0#0 : (tensor<i1>) -> (tensor<i1>, !tf_executor.control) {}
%2:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf_executor.Switch
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%3:3 = tf_executor.Switch %2#0, %1#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%4:2 = tf_executor.island wraps "tf.Identity"(%3#0) : (tensor<i64>) -> tensor<i64>
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%5:2 = tf_executor.island wraps "tf.Identity"(%3#1) : (tensor<i64>) -> tensor<i64>
%6 = tf_executor.ControlTrigger %4#1, %5#1
tf_executor.fetch %6 : !tf_executor.control
}
return
}
// Tests a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink pair
// receives a device when an intermediate op in its loop has a device.
// CHECK-LABEL: func @testNextIterationNoDevice
func @testNextIterationNoDevice() {
tf_executor.graph {
// CHECK: tf_executor.NextIteration.Source
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%0:3 = tf_executor.NextIteration.Source : tensor<i64> {T = "tfdtype$DT_INT64"}
// CHECK: tf.Identity
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
// CHECK: tf.IdentityN
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%2:2 = tf_executor.island wraps "tf.IdentityN"(%1#0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : (tensor<i64>) -> tensor<i64>
// CHECK: tf_executor.NextIteration.Sink
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
tf_executor.NextIteration.Sink [%0#1] %2#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
tf_executor.fetch %0#2 : !tf_executor.control
}
return
}
// Tests tf_executor.NextIteration with mismatched devices does not propagate
// either device.
// CHECK-LABEL: func @testNextIterationMismatchedDevices
func @testNextIterationMismatchedDevices() {
tf_executor.graph {
// CHECK: tf_executor.NextIteration.Source
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
%0:3 = tf_executor.NextIteration.Source : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:1", T = "tfdtype$DT_INT64"}
// CHECK: "tf.Identity"({{.+}}) :
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
// CHECK: tf_executor.NextIteration.Sink
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:0", T = "tfdtype$DT_INT64"}
tf_executor.fetch %0#2 : !tf_executor.control
}
return
}
// CHECK-LABEL: func @testNextIterationMissingSourceDevice
func @testNextIterationMissingSourceDevice() {
tf_executor.graph {
// CHECK: tf_executor.NextIteration.Source
%0:3 = tf_executor.NextIteration.Source : tensor<i64> {T = "tfdtype$DT_INT64"}
// CHECK: "tf.Identity"({{.+}}) :
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
// CHECK: tf_executor.NextIteration.Sink
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:0", T = "tfdtype$DT_INT64"}
tf_executor.fetch %0#2 : !tf_executor.control
}
return
}
// CHECK-LABEL: func @testNextIterationMissingSinkDevice
func @testNextIterationMissingSinkDevice() {
tf_executor.graph {
// CHECK: tf_executor.NextIteration.Source
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
%0:3 = tf_executor.NextIteration.Source : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:1", T = "tfdtype$DT_INT64"}
// CHECK: "tf.Identity"({{.+}}) :
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
// CHECK: tf_executor.NextIteration.Sink
tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
tf_executor.fetch %0#2 : !tf_executor.control
}
return
}
// Tests unsupported functions are not modified.
// CHECK-LABEL: func @testMultipleBlockFunc
func @testMultipleBlockFunc() {
tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf.Identity
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#1 : !tf_executor.control
}
br ^bb1
^bb1:
return
}
// CHECK-LABEL: func @testMultipleGraphs
func @testMultipleGraphs() {
tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf.Identity
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %1#1 : !tf_executor.control
}
tf_executor.graph {
tf_executor.fetch
}
return
}
// CHECK-LABEL: func @testNoGraph
func @testNoGraph() -> tensor<i64> {
%0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf.Identity
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%1 = "tf.Identity"(%0) : (tensor<i64>) -> tensor<i64>
return %1 : tensor<i64>
}
// CHECK-LABEL: func @testMismatchedGraphResults
func @testMismatchedGraphResults() {
%0 = tf_executor.graph {
%1:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: tf.Identity
// CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
%2:2 = tf_executor.island wraps "tf.Identity"(%1#0) : (tensor<i64>) -> tensor<i64>
tf_executor.fetch %2#0 : tensor<i64>
}
return
}

tensorflow/compiler/mlir/tensorflow/transforms/passes.h

@@ -358,6 +358,9 @@ CreateTPUUpdateEmbeddingEnqueueOpInputsPass();
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass();
// Creates a pass that propagates TPU devices to users.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass();
// Populates the supplied pass manager with the passes required to run the
// bridge.
void CreateTPUBridgePipeline(OpPassManager& pm);

tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc

@@ -0,0 +1,253 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <tuple>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";
// Checks if a function only contains a tf_executor.graph.
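// That is, the function has a single block whose only non-terminator op is a
// tf_executor.graph, and the terminator returns exactly the graph results.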
bool IsSupportedGraph(FuncOp func) {
if (!llvm::hasSingleElement(func)) return false;
Block& block = func.front();
if (!llvm::hasSingleElement(block.without_terminator())) return false;
auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
if (!graph) return false;
Operation* terminator = block.getTerminator();
if (graph.getNumResults() != terminator->getNumOperands()) return false;
for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
if (std::get<0>(result) != std::get<1>(result)) return false;
return true;
}
// Checks if an operation of the tf_executor dialect can have TPU devices
// propagated through.
bool IsSupportedExecutorOp(Operation& op) {
auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
return (!lhs_device_attr && !rhs_device_attr) ||
(lhs_device_attr && rhs_device_attr &&
lhs_device_attr.getValue() == rhs_device_attr.getValue());
};
// Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
// pair has matching devices or no devices.
if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
return ops_have_same_device(source, source.GetSink());
} else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
return ops_have_same_device(sink.GetSource(), sink);
}
return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
tf_executor::IslandOp, tf_executor::MergeOp,
tf_executor::SwitchOp>(op);
}
// Assigns all data results to a specified device.
void PopulateDeviceForOpResults(
Operation& op, llvm::StringRef device,
llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
Operation* op_to_update = &op;
// Use the parent tf_executor.island op if present, as results of ops that are
// not v1 control flow ops are forwarded by the wrapping tf_executor.island.
if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
op_to_update = op_to_update->getParentOp();
for (Value result : op_to_update->getResults()) {
if (result.getType().isa<tf_executor::TokenType>()) continue;
if (result.getType().isa<tf_executor::ControlType>()) break;
value_to_device.insert({result, device});
}
}
// Checks if an operation can have TPU devices propagated through.
bool IsSupportedOpToSetDevice(Operation& op) {
return IsSupportedExecutorOp(op) ||
isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
}
// Finds a nonconflicting TPU device for an operation from its operands. If an
// operand has no device or a non-TPU device, or if there are conflicting
// devices, an empty StringRef will be returned. Control dependencies,
// NextIteration.Source -> NextIteration.Sink token dependencies, and
// LoopCond -> Switch data dependencies are ignored.
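// For example, operands with devices {"/device:TPU:0", "/device:TPU:0"} yield
// "/device:TPU:0", while {"/device:TPU:0", "/device:TPU:1"} or an operand with
// no known device yield an empty StringRef.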
llvm::StringRef FindDeviceFromOperands(
Operation& op,
const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
llvm::StringRef new_device;
const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
for (Value operand : op.getOperands()) {
if (operand.getType().isa<tf_executor::TokenType>()) continue;
if (operand.getType().isa<tf_executor::ControlType>()) break;
if (is_switch &&
llvm::isa_and_nonnull<tf_executor::LoopCondOp>(operand.getDefiningOp()))
continue;
auto it = value_to_device.find(operand);
if (it == value_to_device.end()) return llvm::StringRef();
if (new_device.empty()) {
new_device = it->getSecond();
continue;
}
if (new_device != it->getSecond()) return llvm::StringRef();
}
return new_device;
}
// Propagates devices from function arguments.
void PropagateDevicesFromArguments(
FuncOp func, llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
for (BlockArgument& arg : func.getArguments()) {
auto arg_device_attr =
func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
if (!arg_device_attr || arg_device_attr.getValue().empty() ||
!tensorflow::IsTPUDevice(arg_device_attr.getValue()))
continue;
value_to_device.insert({arg, arg_device_attr.getValue()});
}
}
// Propagates devices from operation operands to results. Updating the device of
// a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will result
// in multiple passes over the tf_executor.graph to propagate devices in loops.
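// The walk repeats until no NextIteration pair is updated, so a device set on
// a Sink reaches users of the paired Source on a subsequent walk.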
void PropagateDevicesInGraph(
tf_executor::GraphOp graph,
llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
auto ops = graph.GetBody().without_terminator();
bool updated_next_iteration = false;
do {
updated_next_iteration = false;
for (Operation& op : ops) {
if (!IsSupportedExecutorOp(op)) continue;
Operation* op_to_update = &op;
// Unpack inner op of tf_executor.island.
if (auto island_op =
llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
if (!island_op.WrapsSingleOp()) continue;
op_to_update = &island_op.GetBody().front();
}
// If op already has a TPU device set, simply propagate its device.
auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
const bool has_device = device_attr && !device_attr.getValue().empty();
if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
value_to_device);
continue;
}
// Op has an unsupported device.
if (has_device) continue;
if (!IsSupportedOpToSetDevice(*op_to_update)) continue;
llvm::StringRef new_device =
FindDeviceFromOperands(*op_to_update, value_to_device);
if (new_device.empty()) continue;
auto new_device_attr =
mlir::StringAttr::get(new_device, op_to_update->getContext());
op_to_update->setAttr(kDeviceAttr, new_device_attr);
PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
value_to_device);
if (auto sink =
llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
auto source = sink.GetSource();
source.setAttr(kDeviceAttr, new_device_attr);
PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
value_to_device);
updated_next_iteration = true;
}
}
} while (updated_next_iteration);
}
// Propagates devices to function results.
void PropagateDevicesToResults(
FuncOp func, tf_executor::FetchOp fetch,
const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
if (operand.get().getType().isa<tf_executor::ControlType>()) break;
auto it = value_to_device.find(operand.get());
if (it != value_to_device.end()) {
auto device_attr = func.getResultAttrOfType<StringAttr>(
operand.getOperandNumber(), kFuncDeviceAttr);
if (device_attr && !device_attr.getValue().empty()) continue;
func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
StringAttr::get(it->getSecond(), func.getContext()));
}
}
}
struct TPUDevicePropagation
: public PassWrapper<TPUDevicePropagation, FunctionPass> {
void runOnFunction() override;
};
void TPUDevicePropagation::runOnFunction() {
FuncOp func = getFunction();
if (!IsSupportedGraph(func)) return;
llvm::DenseMap<Value, llvm::StringRef> value_to_device;
PropagateDevicesFromArguments(func, value_to_device);
auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
PropagateDevicesInGraph(graph, value_to_device);
PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
return std::make_unique<TPUDevicePropagation>();
}
static PassRegistration<TPUDevicePropagation> pass(
"tf-tpu-device-propagation", "Propagates TPU devices from ops to users");
} // namespace TFTPU
} // namespace mlir