From d39bcbaefbc8082428ec2414e4d07769b3ede7d2 Mon Sep 17 00:00:00 2001
From: Andy Ly
Date: Fri, 23 Oct 2020 18:06:58 -0700
Subject: [PATCH] Add pass that propagates TPU devices from operands to uses.

This pass finds values placed on a TPU device and propagates that device to
certain users. Supported ops include certain tf_executor dialect ops and
trivial tf dialect ops (Identity, IdentityN, Shape).

PiperOrigin-RevId: 338778252
Change-Id: Idbea53ea5a52df036cdd249f8d894e052a81a6a5
---
 tensorflow/compiler/mlir/tensorflow/BUILD     |   1 +
 .../tests/tpu_device_propagation.mlir         | 383 ++++++++++++++++++
 .../mlir/tensorflow/transforms/passes.h       |   3 +
 .../transforms/tpu_device_propagation.cc      | 253 ++++++++++++
 4 files changed, 640 insertions(+)
 create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tpu_device_propagation.mlir
 create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc

diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index a2c32c46b11..77615f65d9e 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -884,6 +884,7 @@ cc_library(
         "transforms/tpu_cluster_cleanup_attributes.cc",
         "transforms/tpu_cluster_formation.cc",
         "transforms/tpu_colocate_composite_resource_ops.cc",
+        "transforms/tpu_device_propagation.cc",
         "transforms/tpu_dynamic_layout_pass.cc",
         "transforms/tpu_dynamic_padding_mapper.cc",
         "transforms/tpu_extract_head_tail_outside_compilation.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_device_propagation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_device_propagation.mlir
new file mode 100644
index 00000000000..39d6df513fa
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_device_propagation.mlir
@@ -0,0 +1,383 @@
+// RUN: tf-opt %s -tf-tpu-device-propagation | FileCheck %s
+
+// Tests function passthrough values.
+
+// CHECK-LABEL: func @testArgToRet
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+func @testArgToRet(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
+  %0 = tf_executor.graph {
+    tf_executor.fetch %arg0 : tensor<i64>
+  }
+  return %0 : tensor<i64>
+}
+
+// Tests supported ops.
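+// Devices are expected to propagate through tf.Identity, tf.IdentityN, and
+// tf.Shape, as well as the tf_executor Enter, Exit, Merge, Switch, and
+// NextIteration ops exercised below.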
+
+// CHECK-LABEL: func @testIdentityOp
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
func @testIdentityOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
+  %0 = tf_executor.graph {
+    // CHECK: tf.Identity
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %1#0 : tensor<i64>
+  }
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: func @testIdentityNOp
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+func @testIdentityNOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64>, tensor<i64>) {
+  %0:2 = tf_executor.graph {
+    // CHECK: tf.IdentityN
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
+    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i64>
+  }
+  return %0#0, %0#1 : tensor<i64>, tensor<i64>
+}
+
+// CHECK-LABEL: func @testShapeOp
+// CHECK-SAME: ({{%.+}}: tensor<*xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<?xi32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+func @testShapeOp(%arg0: tensor<*xi64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<?xi32> {
+  %0 = tf_executor.graph {
+    // CHECK: tf.Shape
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island wraps "tf.Shape"(%arg0) : (tensor<*xi64>) -> tensor<?xi32>
+    tf_executor.fetch %1#0 : tensor<?xi32>
+  }
+  return %0 : tensor<?xi32>
+}
+
+// CHECK-LABEL: func @testEnterOp
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+func @testEnterOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
+  %0 = tf_executor.graph {
+    // CHECK: tf_executor.Enter
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.Enter %arg0 frame "frame" : tensor<i64>
+    tf_executor.fetch %1#0 : tensor<i64>
+  }
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: func @testExitOp
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+func @testExitOp(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i64> {
+  %0 = tf_executor.graph {
+    // CHECK: tf_executor.Exit
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.Exit %arg0 : tensor<i64>
+    tf_executor.fetch %1#0 : tensor<i64>
+  }
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: func @testMergeOp
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
"/job:localhost/replica:0/task:0/device:TPU:0"}) +func @testMergeOp(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + // CHECK: tf_executor.Merge + // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0" + %1:3 = tf_executor.Merge %arg0, %arg1 : tensor + tf_executor.fetch %1#0, %1#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: func @testSwitchOp +// CHECK-SAME: ({{%.+}}: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) +func @testSwitchOp(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) { + tf_executor.graph { + // CHECK: tf_executor.Switch + // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0" + %0:3 = tf_executor.Switch %arg0, %arg1 : tensor {T = "tfdtype$DT_INT64"} + // CHECK: tf.Identity + // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0" + %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor) -> tensor + // CHECK: tf.Identity + // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0" + %2:2 = tf_executor.island wraps "tf.Identity"(%0#1) : (tensor) -> tensor + %3 = tf_executor.ControlTrigger %1#1, %2#1 + tf_executor.fetch %3 : !tf_executor.control + } + return +} + +// Tests unsupported op does not have TPU device propagated. + +// CHECK-LABEL: func @testUnsupportedOp +// CHECK-SAME: ({{%.+}}: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) +// CHECK-SAME: -> tensor +func @testUnsupportedOp(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor { + %0 = tf_executor.graph { + // CHECK: tf.UnsupportedOp + // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0" + %1:2 = tf_executor.island wraps "tf.UnsupportedOp"(%arg0) : (tensor) -> tensor + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor +} + +// Tests empty devices are overwritten. + +// CHECK-LABEL: func @testEmptyDeviceOverwritten +// CHECK-SAME: ({{%.+}}: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) +// CHECK-SAME: -> (tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) +func @testEmptyDeviceOverwritten(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor {tf.device = ""}) { + %0 = tf_executor.graph { + // CHECK: tf.Identity + // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0" + %1:2 = tf_executor.island wraps "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor +} + +// Tests only devices are propagated when all operands are on the same TPU +// device. 
+
+// CHECK-LABEL: func @testOperandsNoDevice
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i64>)
+// CHECK-SAME: -> (tensor<i64>, tensor<i64>)
+func @testOperandsNoDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i64>) -> (tensor<i64>, tensor<i64>) {
+  %0:2 = tf_executor.graph {
+    // CHECK: tf.IdentityN
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
+    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i64>
+  }
+  return %0#0, %0#1 : tensor<i64>, tensor<i64>
+}
+
+// CHECK-LABEL: func @testOperandsDifferentDevice
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, {{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"})
+// CHECK-SAME: -> (tensor<i64>, tensor<i64>)
+func @testOperandsDifferentDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg1: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}) -> (tensor<i64>, tensor<i64>) {
+  %0:2 = tf_executor.graph {
+    // CHECK: tf.IdentityN
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:1"
+    %1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
+    tf_executor.fetch %1#0, %1#1 : tensor<i64>, tensor<i64>
+  }
+  return %0#0, %0#1 : tensor<i64>, tensor<i64>
+}
+
+// Tests an op with an operand on a different device does not have its device
+// overwritten.
+
+// CHECK-LABEL: func @testDifferentOperandAndOpDevice
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+func @testDifferentOperandAndOpDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
+  tf_executor.graph {
+    // CHECK: tf.Identity
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
+    %0:2 = tf_executor.island wraps "tf.Identity"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %0#1 : !tf_executor.control
+  }
+  return
+}
+
+// CHECK-LABEL: func @testDifferentOperandAndResultDevice
+// CHECK-SAME: ({{%.+}}: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
+// CHECK-SAME: -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"})
+func @testDifferentOperandAndResultDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}) {
+  %0 = tf_executor.graph {
+    tf_executor.fetch %arg0 : tensor<i64>
+  }
+  return %0 : tensor<i64>
+}
+
+// Tests non-TPU devices are not propagated.
+
+// CHECK-LABEL: func @testNonTPUDevice
+func @testNonTPUDevice(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) {
+  tf_executor.graph {
+    // CHECK: tf.Identity
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
+    %0:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %0#1 : !tf_executor.control
+  }
+  return
+}
+
+// Tests control dependencies are ignored for propagating devices.
+
+// CHECK-LABEL: func @testControlDependenciesIgnored
+func @testControlDependenciesIgnored(%arg0: tensor<i64>) {
+  tf_executor.graph {
+    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+    // CHECK: tf.Identity
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island(%0#1) wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %1#1 : !tf_executor.control
+  }
+  return
+}
+
+// CHECK-LABEL: func @testControlDependenciesMismatchedDevices
+func @testControlDependenciesMismatchedDevices(%arg0: tensor<i64> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) {
+  tf_executor.graph {
+    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:1", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+    // CHECK: tf.Identity
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island(%0#1) wraps "tf.Identity"(%arg0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %1#1 : !tf_executor.control
+  }
+  return
+}
+
+// Tests the LoopCond -> Switch data dependency is ignored when the LoopCond
+// is on a different device.
+
+// CHECK-LABEL: func @testLoopCondSwitchLinkDifferentDevice
+func @testLoopCondSwitchLinkDifferentDevice() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<true> : tensor<i1>} : () -> tensor<i1>
+    %1:2 = tf_executor.LoopCond %0#0 : (tensor<i1>) -> (tensor<i1>, !tf_executor.control) {}
+    %2:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+    // CHECK: tf_executor.Switch
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %3:3 = tf_executor.Switch %2#0, %1#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
+    // CHECK: tf.Identity
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %4:2 = tf_executor.island wraps "tf.Identity"(%3#0) : (tensor<i64>) -> tensor<i64>
+    // CHECK: tf.Identity
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %5:2 = tf_executor.island wraps "tf.Identity"(%3#1) : (tensor<i64>) -> tensor<i64>
+    %6 = tf_executor.ControlTrigger %4#1, %5#1
+    tf_executor.fetch %6 : !tf_executor.control
+  }
+  return
+}
+
+// Tests tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink get a
+// device set when an intermediate op in the loop has a device.
+
+// CHECK-LABEL: func @testNextIterationNoDevice
+func @testNextIterationNoDevice() {
+  tf_executor.graph {
+    // CHECK: tf_executor.NextIteration.Source
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {T = "tfdtype$DT_INT64"}
+    // CHECK: tf.Identity
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
+    // CHECK: tf.IdentityN
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %2:2 = tf_executor.island wraps "tf.IdentityN"(%1#0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : (tensor<i64>) -> tensor<i64>
+    // CHECK: tf_executor.NextIteration.Sink
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    tf_executor.NextIteration.Sink [%0#1] %2#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
+    tf_executor.fetch %0#2 : !tf_executor.control
+  }
+  return
+}
+
+// Tests tf_executor.NextIteration with mismatched devices does not propagate
+// either device.
+
+// CHECK-LABEL: func @testNextIterationMismatchedDevices
+func @testNextIterationMismatchedDevices() {
+  tf_executor.graph {
+    // CHECK: tf_executor.NextIteration.Source
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
+    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:1", T = "tfdtype$DT_INT64"}
+    // CHECK: "tf.Identity"({{.+}}) :
+    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
+    // CHECK: tf_executor.NextIteration.Sink
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:0", T = "tfdtype$DT_INT64"}
+    tf_executor.fetch %0#2 : !tf_executor.control
+  }
+  return
+}
+
+// CHECK-LABEL: func @testNextIterationMissingSourceDevice
+func @testNextIterationMissingSourceDevice() {
+  tf_executor.graph {
+    // CHECK: tf_executor.NextIteration.Source
+    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {T = "tfdtype$DT_INT64"}
+    // CHECK: "tf.Identity"({{.+}}) :
+    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
+    // CHECK: tf_executor.NextIteration.Sink
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:0", T = "tfdtype$DT_INT64"}
+    tf_executor.fetch %0#2 : !tf_executor.control
+  }
+  return
+}
+
+// CHECK-LABEL: func @testNextIterationMissingSinkDevice
+func @testNextIterationMissingSinkDevice() {
+  tf_executor.graph {
+    // CHECK: tf_executor.NextIteration.Source
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
+    %0:3 = tf_executor.NextIteration.Source : tensor<i64> {device = "/job:localhost/replica:0/task:0/device:TPU:1", T = "tfdtype$DT_INT64"}
+    // CHECK: "tf.Identity"({{.+}}) :
+    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
+    // CHECK: tf_executor.NextIteration.Sink
+    tf_executor.NextIteration.Sink [%0#1] %1#0 : tensor<i64> {T = "tfdtype$DT_INT64"}
+    tf_executor.fetch %0#2 : !tf_executor.control
+  }
+  return
+}
+
+// Tests unsupported functions are not modified.
+
+// CHECK-LABEL: func @testMultipleBlockFunc
+func @testMultipleBlockFunc() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+    // CHECK: tf.Identity
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %1#1 : !tf_executor.control
+  }
+  br ^bb1
+^bb1:
+  return
+}
+
+// CHECK-LABEL: func @testMultipleGraphs
+func @testMultipleGraphs() {
+  tf_executor.graph {
+    %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+    // CHECK: tf.Identity
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %1#1 : !tf_executor.control
+  }
+  tf_executor.graph {
+    tf_executor.fetch
+  }
+  return
+}
+
+// CHECK-LABEL: func @testNoGraph
+func @testNoGraph() -> tensor<i64> {
+  %0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK: tf.Identity
+  // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+  %1 = "tf.Identity"(%0) : (tensor<i64>) -> tensor<i64>
+  return %1 : tensor<i64>
+}
+
+// CHECK-LABEL: func @testMismatchedGraphResults
+func @testMismatchedGraphResults() {
+  %0 = tf_executor.graph {
+    %1:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", value = dense<0> : tensor<i64>} : () -> tensor<i64>
+    // CHECK: tf.Identity
+    // CHECK-NOT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
+    %2:2 = tf_executor.island wraps "tf.Identity"(%1#0) : (tensor<i64>) -> tensor<i64>
+    tf_executor.fetch %2#0 : tensor<i64>
+  }
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 4a12c80c8d1..3a7c1b6300b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -358,6 +358,9 @@ CreateTPUUpdateEmbeddingEnqueueOpInputsPass();
 std::unique_ptr<OperationPass<ModuleOp>>
 CreateTPUExtractOutsideCompilationPass();
 
+// Creates a pass that propagates TPU devices to users.
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass();
+
 // Populates the supplied passmanager with the passes required to run the
 // bridge.
 void CreateTPUBridgePipeline(OpPassManager& pm);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc
new file mode 100644
index 00000000000..6771ad1b923
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc
@@ -0,0 +1,253 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <tuple>
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/UseDefLists.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
+
+namespace mlir {
+namespace TFTPU {
+
+namespace {
+
+constexpr char kDeviceAttr[] = "device";
+constexpr char kFuncDeviceAttr[] = "tf.device";
+
+// Checks if a function only contains a tf_executor.graph.
+bool IsSupportedGraph(FuncOp func) {
+  if (!llvm::hasSingleElement(func)) return false;
+
+  Block& block = func.front();
+  if (!llvm::hasSingleElement(block.without_terminator())) return false;
+
+  auto graph = llvm::dyn_cast<tf_executor::GraphOp>(block.front());
+  if (!graph) return false;
+
+  Operation* terminator = block.getTerminator();
+  if (graph.getNumResults() != terminator->getNumOperands()) return false;
+  for (auto result : llvm::zip(graph.results(), terminator->getOperands()))
+    if (std::get<0>(result) != std::get<1>(result)) return false;
+
+  return true;
+}
+
+// Checks if an operation of the tf_executor dialect can have TPU devices
+// propagated through.
+bool IsSupportedExecutorOp(Operation& op) {
+  auto ops_have_same_device = [](Operation* lhs, Operation* rhs) {
+    auto lhs_device_attr = lhs->getAttrOfType<StringAttr>(kDeviceAttr);
+    auto rhs_device_attr = rhs->getAttrOfType<StringAttr>(kDeviceAttr);
+    return (!lhs_device_attr && !rhs_device_attr) ||
+           (lhs_device_attr && rhs_device_attr &&
+            lhs_device_attr.getValue() == rhs_device_attr.getValue());
+  };
+
+  // Check if tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink
+  // pair has matching devices or no devices.
+  if (auto source = llvm::dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
+    return ops_have_same_device(source, source.GetSink());
+  } else if (auto sink = llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
+    return ops_have_same_device(sink.GetSource(), sink);
+  }
+
+  return llvm::isa<tf_executor::EnterOp, tf_executor::ExitOp,
+                   tf_executor::IslandOp, tf_executor::MergeOp,
+                   tf_executor::SwitchOp>(op);
+}
+
+// Assigns all data results to a specified device.
+void PopulateDeviceForOpResults(
+    Operation& op, llvm::StringRef device,
+    llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
+  Operation* op_to_update = &op;
+  // Use tf_executor.island op if present as non v1 control flow op results are
+  // forwarded by a parent tf_executor.island op.
+  if (llvm::isa<tf_executor::IslandOp>(op_to_update->getParentOp()))
+    op_to_update = op_to_update->getParentOp();
+
+  for (Value result : op_to_update->getResults()) {
+    if (result.getType().isa<tf_executor::TokenType>()) continue;
+    if (result.getType().isa<tf_executor::ControlType>()) break;
+
+    value_to_device.insert({result, device});
+  }
+}
+
+// Checks if an operation can have TPU devices propagated through.
+bool IsSupportedOpToSetDevice(Operation& op) {
+  return IsSupportedExecutorOp(op) ||
+         isa<TF::IdentityOp, TF::IdentityNOp, TF::ShapeOp>(op);
+}
+
+// Finds a nonconflicting TPU device for an operation from its operands. If an
+// operand has no device or a non-TPU device, or if operand devices conflict,
+// an empty StringRef is returned. Control dependencies,
+// NextIteration.Source -> NextIteration.Sink token dependencies, and
+// LoopCond -> Switch data dependencies are ignored.
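+// For example, an op whose operands all map to "TPU:0" is assigned "TPU:0",
+// while an op with operands on "TPU:0" and "TPU:1" is left unassigned.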
+llvm::StringRef FindDeviceFromOperands(
+    Operation& op,
+    const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
+  llvm::StringRef new_device;
+  const bool is_switch = llvm::isa<tf_executor::SwitchOp>(op);
+  for (Value operand : op.getOperands()) {
+    if (operand.getType().isa<tf_executor::TokenType>()) continue;
+    if (operand.getType().isa<tf_executor::ControlType>()) break;
+
+    if (is_switch && llvm::isa_and_nonnull<tf_executor::LoopCondOp>(
+                         operand.getDefiningOp()))
+      continue;
+
+    auto it = value_to_device.find(operand);
+    if (it == value_to_device.end()) return llvm::StringRef();
+
+    if (new_device.empty()) {
+      new_device = it->getSecond();
+      continue;
+    }
+
+    if (new_device != it->getSecond()) return llvm::StringRef();
+  }
+
+  return new_device;
+}
+
+// Propagates devices from function arguments.
+void PropagateDevicesFromArguments(
+    FuncOp func, llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
+  for (BlockArgument& arg : func.getArguments()) {
+    auto arg_device_attr =
+        func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), kFuncDeviceAttr);
+    if (!arg_device_attr || arg_device_attr.getValue().empty() ||
+        !tensorflow::IsTPUDevice(arg_device_attr.getValue()))
+      continue;
+    value_to_device.insert({arg, arg_device_attr.getValue()});
+  }
+}
+
+// Propagates devices from operation operands to results. Updating the device
+// of a tf_executor.NextIteration.Source/tf_executor.NextIteration.Sink will
+// result in multiple passes over the tf_executor.graph to propagate devices
+// in loops.
+void PropagateDevicesInGraph(
+    tf_executor::GraphOp graph,
+    llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
+  auto ops = graph.GetBody().without_terminator();
+
+  bool updated_next_iteration = false;
+  do {
+    updated_next_iteration = false;
+    for (Operation& op : ops) {
+      if (!IsSupportedExecutorOp(op)) continue;
+
+      Operation* op_to_update = &op;
+      // Unpack inner op of tf_executor.island.
+      if (auto island_op =
+              llvm::dyn_cast<tf_executor::IslandOp>(op_to_update)) {
+        if (!island_op.WrapsSingleOp()) continue;
+        op_to_update = &island_op.GetBody().front();
+      }
+
+      // If op already has a TPU device set, simply propagate its device.
+      auto device_attr = op_to_update->getAttrOfType<StringAttr>(kDeviceAttr);
+      const bool has_device = device_attr && !device_attr.getValue().empty();
+      if (has_device && tensorflow::IsTPUDevice(device_attr.getValue())) {
+        PopulateDeviceForOpResults(*op_to_update, device_attr.getValue(),
+                                   value_to_device);
+        continue;
+      }
+
+      // Op has an unsupported device.
+      if (has_device) continue;
+
+      if (!IsSupportedOpToSetDevice(*op_to_update)) continue;
+
+      llvm::StringRef new_device =
+          FindDeviceFromOperands(*op_to_update, value_to_device);
+      if (new_device.empty()) continue;
+
+      auto new_device_attr =
+          mlir::StringAttr::get(new_device, op_to_update->getContext());
+      op_to_update->setAttr(kDeviceAttr, new_device_attr);
+      PopulateDeviceForOpResults(*op_to_update, new_device_attr.getValue(),
+                                 value_to_device);
+
+      if (auto sink =
+              llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
+        auto source = sink.GetSource();
+        source.setAttr(kDeviceAttr, new_device_attr);
+        PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
+                                   value_to_device);
+        updated_next_iteration = true;
+      }
+    }
+  } while (updated_next_iteration);
+}
+
+// Propagates devices to function results.
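+// A result attribute is only set when the fetched value has a known TPU
+// device and the function result does not already carry a non-empty tf.device
+// attribute.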
+void PropagateDevicesToResults(
+    FuncOp func, tf_executor::FetchOp fetch,
+    const llvm::DenseMap<Value, llvm::StringRef>& value_to_device) {
+  for (OpOperand& operand : fetch.getOperation()->getOpOperands()) {
+    if (operand.get().getType().isa<tf_executor::ControlType>()) break;
+    auto it = value_to_device.find(operand.get());
+    if (it != value_to_device.end()) {
+      auto device_attr = func.getResultAttrOfType<StringAttr>(
+          operand.getOperandNumber(), kFuncDeviceAttr);
+      if (device_attr && !device_attr.getValue().empty()) continue;
+      func.setResultAttr(operand.getOperandNumber(), kFuncDeviceAttr,
+                         StringAttr::get(it->getSecond(), func.getContext()));
+    }
+  }
+}
+
+struct TPUDevicePropagation
+    : public PassWrapper<TPUDevicePropagation, FunctionPass> {
+  void runOnFunction() override;
+};
+
+void TPUDevicePropagation::runOnFunction() {
+  FuncOp func = getFunction();
+  if (!IsSupportedGraph(func)) return;
+
+  llvm::DenseMap<Value, llvm::StringRef> value_to_device;
+  PropagateDevicesFromArguments(func, value_to_device);
+  auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
+  PropagateDevicesInGraph(graph, value_to_device);
+  PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
+  return std::make_unique<TPUDevicePropagation>();
+}
+
+static PassRegistration<TPUDevicePropagation> pass(
+    "tf-tpu-device-propagation", "Propagates TPU devices from ops to users");
+
+}  // namespace TFTPU
+}  // namespace mlir
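The pass is registered under the `tf-tpu-device-propagation` flag (see the
RUN line in the test file). As a minimal sketch of wiring it into a custom
pipeline programmatically, assuming a hypothetical `BuildPipeline` helper
(only `CreateTPUDevicePropagationPass` comes from this patch):

    #include "mlir/IR/Function.h"
    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

    // Hypothetical helper: nests the function pass under the module-level
    // pass manager so it runs on every function in the module.
    void BuildPipeline(mlir::OpPassManager& pm) {
      pm.addNestedPass<mlir::FuncOp>(
          mlir::TFTPU::CreateTPUDevicePropagationPass());
    }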