diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 71b6cd137ac..233d5a3ced3 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -868,6 +868,7 @@ cc_library( "transforms/collection_ops_util.cc", "transforms/constant_op_device_assignment.cc", "transforms/contraction_fusion.cc", + "transforms/cross_host_transfer.cc", "transforms/decompose_resource_ops_pass.cc", "transforms/device_index_selector.cc", "transforms/drop_while_shape_invariant.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index efc9d4cadcd..92d3815b576 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -383,4 +383,41 @@ This op captures all needed live-in values. }]; } +def TfDevice_SendOp : TfDevice_Op<"send", []> { + let summary = "Send a value to a host."; + + let description = [{ + Send the value to the given host with the given rendezvous key. + }]; + + let arguments = (ins + AnyType:$value, + StrAttr:$key, + StrAttr:$dst_host + ); + + let results = (outs); + + let assemblyFormat = [{$value $key $dst_host attr-dict `:` type($value)}]; +} + +def TfDevice_ReceiveOp : TfDevice_Op<"receive", []> { + let summary = "Receive a value from a host."; + + let description = [{ + Receive a value from the given host with the given rendezvous key. 
+ }]; + + let arguments = (ins + StrAttr:$key, + StrAttr:$src_host + ); + + let results = (outs + AnyType:$result + ); + + let assemblyFormat = [{$key $src_host attr-dict `:` type($result)}]; +} + #endif // TF_DEVICE_DIALECT diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cross_host_transfer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cross_host_transfer.mlir new file mode 100644 index 00000000000..dd1437b4920 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/cross_host_transfer.mlir @@ -0,0 +1,67 @@ +// RUN: tf-opt --tf-cross-host-transfer %s | FileCheck %s + +// CHECK-LABEL: func @test_merge_send +func @test_merge_send() { + // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor} + %0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor} : () -> tensor + + // CHECK-NEXT: tf_device.send %[[RESULT_0]] "key-0" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.receive "key-0" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"} + %1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor) -> tensor + + // CHECK-NEXT: %[[RESULT_3:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"} + %2 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor) -> tensor + return +} + +// CHECK-LABEL: func @test_multiple_send +func @test_multiple_send() -> tensor { + // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor} + %0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor} : 
() -> tensor + + // CHECK-NEXT: tf_device.send %[[RESULT_0]] "key-1" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.receive "key-1" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"} + %1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor) -> tensor + + // CHECK-NEXT: tf_device.send %[[RESULT_2]] "key-2" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_3:.*]] = tf_device.receive "key-2" "/job:worker/replica:0/task:1" {device = "/job:localhost/replica:0/task:0/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_4:.*]] = "tf.Identity"(%[[RESULT_3]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} + %2 = "tf.Identity"(%1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + + // CHECK-NEXT: return %[[RESULT_4]] : tensor + return %2 : tensor +} + +// CHECK: func @test_send_func_arg(%[[ARG_0:.*]]: tensor {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) { +func @test_send_func_arg(%arg0: tensor {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) { + // CHECK-NEXT: tf_device.send %[[ARG_0]] "key-3" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:0/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_0:.*]] = tf_device.receive "key-3" "/job:worker/replica:0/task:0" {device = "/job:localhost/replica:0/task:0/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[RESULT_0]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} + %0 = "tf.Identity"(%arg0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + + return +} + +// CHECK: func @test_not_send_while_loop_arg(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor<*xf32>, %[[ARG_2:.*]]: tensor) { 
+func @test_not_send_while_loop_arg(%arg0: tensor, %arg1: tensor<*xf32>, %arg2: tensor) { + // CHECK-NEXT: %[[RESULT_0:.*]]:2 = "tf.WhileRegion"(%[[ARG_0]], %[[ARG_1]]) ( { + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( { + // CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor, %[[ARG_4:.*]]: tensor<*xf32>) + ^bb0(%arg3: tensor, %arg4: tensor<*xf32>): + // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"} + %2 = "tf.Identity"(%arg3) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor) -> tensor + // CHECK-NEXT: tf_device.send %[[RESULT_1]] "key-4" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_2:.*]] = tf_device.receive "key-4" "/job:worker/replica:0/task:1" {device = "/job:localhost/replica:0/task:0/device:CPU:0"} + // CHECK-NEXT: %[[RESULT_3:.*]] = "tf.NotEqual"(%[[ARG_2]], %[[RESULT_2]]) + %3 = "tf.NotEqual"(%arg2, %2) : (tensor, tensor) -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor<*xf32>): + %cst = constant dense<1> : tensor + %1 = "tf.Sub"(%arg3, %cst) : (tensor, tensor) -> tensor + "tf.Yield"(%1, %arg4) : (tensor, tensor<*xf32>) -> () + }) {is_stateless = true} : (tensor, tensor<*xf32>) -> (tensor, tensor<*xf32>) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc new file mode 100644 index 00000000000..155d1f60adb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc @@ -0,0 +1,158 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This pass inserts tf_device.send and tf_device.receive ops to make sure any +// argument of any op is on the same host of the op itself. + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace mlir { +namespace TF { + +namespace { + +using DeviceNameUtils = ::tensorflow::DeviceNameUtils; + +constexpr const char *kOpDeviceAttr = "device"; +constexpr const char *kArgDeviceAttr = "tf.device"; +// TODO(b/175480458): Do not assign default host once every op in the TF +// dialect has the device attribute. +constexpr const char *kDefaultHost = "/job:localhost/replica:0/task:0"; +constexpr const char *kCPUDevice = "/device:CPU:0"; + +// Return the job/replica/task from the device name as the host address. If no +// job/replica/task is specified, return /job:localhost/replica:0/task:0 as the +// default host address. 
+static std::string GetHost(const std::string &device) { + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseFullName(device, &parsed_name); + parsed_name.has_id = false; + parsed_name.has_type = false; + + auto host = DeviceNameUtils::ParsedNameToString(parsed_name); + if (host.empty()) return kDefaultHost; + + return host; +} + +// Return a globally unique string as the rendezvous key for cross-host value +// transfer. +static std::string GetNextKey() { + static int64_t next_index = 0; + std::string key; + llvm::raw_string_ostream os(key); + os << "key-" << next_index; + next_index++; + + return key; +} + +struct CrossHostTransferPass + : public PassWrapper<CrossHostTransferPass, FunctionPass> { + void runOnFunction() override; +}; + +void CrossHostTransferPass::runOnFunction() { + FuncOp func_op = getOperation(); + // This map is used to avoid transferring the same value to the same host + // multiple times. + llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>> + transferred_value_by_value_and_host; + + func_op.getBody().walk([&](Operation *op) { + if (op->isKnownTerminator()) return WalkResult::advance(); + + OpBuilder builder(op); + // Get the host address of the op. + std::string op_device = ""; + if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kOpDeviceAttr)) { + op_device = device_attr.getValue().str(); + } + std::string dst_host = GetHost(op_device); + + for (mlir::Value arg : op->getOperands()) { + // Get the host address of the argument. + std::string arg_device = ""; + if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) { + // Do not send this argument if it is not a function's argument. This + // can happen when the argument is a while loop's argument. 
+ if (block_arg.getParentRegion() != &func_op.getRegion()) continue; + + if (StringAttr device_attr = func_op.getArgAttrOfType<StringAttr>( + block_arg.getArgNumber(), kArgDeviceAttr)) { + arg_device = device_attr.getValue().str(); + } + } else { + Operation *defining_op = arg.getDefiningOp(); + if (StringAttr device_attr = + defining_op->getAttrOfType<StringAttr>(kOpDeviceAttr)) { + arg_device = device_attr.getValue().str(); + } + } + std::string src_host = GetHost(arg_device); + + if (src_host == dst_host) continue; + + // Re-use the transferred argument if the argument has already been + // transferred to the given host. + llvm::StringMap<mlir::Value> &transferred_value_by_host = + transferred_value_by_value_and_host[arg]; + auto iter = transferred_value_by_host.find(dst_host); + if (iter != transferred_value_by_host.end()) { + op->replaceUsesOfWith(arg, iter->second); + continue; + } + + // Create tf_device.send and tf_device.receive ops to send the argument to + // the same host as the operation. + std::string key = GetNextKey(); + auto send_op = + builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host); + send_op.setAttr(kOpDeviceAttr, + builder.getStringAttr(src_host + kCPUDevice)); + + auto receive_op = builder.create<tf_device::ReceiveOp>( + op->getLoc(), arg.getType(), key, src_host); + receive_op.setAttr(kOpDeviceAttr, + builder.getStringAttr(dst_host + kCPUDevice)); + + transferred_value_by_host[dst_host] = receive_op.getResult(); + op->replaceUsesOfWith(arg, receive_op.getResult()); + } + return WalkResult::advance(); + }); +} + +} // namespace + +std::unique_ptr<FunctionPass> CreateCrossHostTransferPass() { + return std::make_unique<CrossHostTransferPass>(); +} + +static PassRegistration<CrossHostTransferPass> pass( + "tf-cross-host-transfer", + "This pass inserts tf_device.send and tf_device.receive ops to make sure " + "any argument of any op is on the same host of the op itself."); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 
609fab0e30b..f6e56afb85c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -199,6 +199,10 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateInitTextFileToImportPass(); // assignment of the result. std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass(); +// Creates function pass to insert tf_device.send and tf_device.receive ops to +// make sure any argument of any op is on the same host as the op itself. +std::unique_ptr<FunctionPass> CreateCrossHostTransferPass(); + // Creates a pass that adds the device attribute to every tf.Const op based on // the device attribute of the operations that read its result. If the result of // a tf.Const op is read by operations placed on multiple devices, then the pass