[TFRT:Compiler] Add the tf-cross-host-transfer compiler pass

This pass inserts tf_device.send and tf_device.receive ops to make sure any argument of any op is on the same host as the op itself. This allows later compiler passes to not worry about transferring values between hosts.
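A rough before/after sketch of the rewrite, mirroring the first test case added below (the device strings and the "key-0" rendezvous key are illustrative):

Before:
  %0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>

After:
  %0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.0> : tensor<f32>} : () -> tensor<f32>
  tf_device.send %0 "key-0" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"} : tensor<f32>
  %recv = tf_device.receive "key-0" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"} : tensor<f32>
  %1 = "tf.Sqrt"(%recv) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>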

PiperOrigin-RevId: 347908086
Change-Id: Iaf2bec1919d36eb9a913e488862432bc113ea15b
Author: Dong Lin (committed by TensorFlower Gardener)
Date: 2020-12-16 15:20:47 -08:00
Parent: 39cbd155f4
Commit: d90ce10b0c
5 changed files with 267 additions and 0 deletions


@@ -868,6 +868,7 @@ cc_library(
"transforms/collection_ops_util.cc",
"transforms/constant_op_device_assignment.cc",
"transforms/contraction_fusion.cc",
"transforms/cross_host_transfer.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
"transforms/drop_while_shape_invariant.cc",


@@ -383,4 +383,41 @@ This op captures all needed live-in values.
}];
}
def TfDevice_SendOp : TfDevice_Op<"send", []> {
let summary = "Send a value to a host.";
let description = [{
Send the value to the given host with the given rendezvous key.
}];
let arguments = (ins
AnyType:$value,
StrAttr:$key,
StrAttr:$dst_host
);
let results = (outs);
let assemblyFormat = [{$value $key $dst_host attr-dict `:` type($value)}];
}
def TfDevice_ReceiveOp : TfDevice_Op<"receive", []> {
let summary = "Receive a value from a host.";
let description = [{
Receive a value from the given host with the given rendezvous key.
}];
let arguments = (ins
StrAttr:$key,
StrAttr:$src_host
);
let results = (outs
AnyType:$result
);
let assemblyFormat = [{$key $src_host attr-dict `:` type($result)}];
}
#endif // TF_DEVICE_DIALECT
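Per the assemblyFormat strings above, both ops print in the compact textual form exercised by the test file that follows; the device attribute, when present, appears in the attr-dict between the host string and the trailing type. For example (SSA names illustrative):

  tf_device.send %value "key-0" "/job:worker/replica:0/task:1" : tensor<f32>
  %result = tf_device.receive "key-0" "/job:worker/replica:0/task:0" : tensor<f32>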


@@ -0,0 +1,67 @@
// RUN: tf-opt --tf-cross-host-transfer %s | FileCheck %s
// CHECK-LABEL: func @test_merge_send
func @test_merge_send() {
// CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>}
%0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-NEXT: tf_device.send %[[RESULT_0]] "key-0" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.receive "key-0" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
%1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[RESULT_3:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
%2 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
return
}
// CHECK-LABEL: func @test_multiple_send
func @test_multiple_send() -> tensor<f32> {
// CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>}
%0 = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-NEXT: tf_device.send %[[RESULT_0]] "key-1" "/job:worker/replica:0/task:1" {device = "/job:worker/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.receive "key-1" "/job:worker/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Sqrt"(%[[RESULT_1]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
%1 = "tf.Sqrt"(%0) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: tf_device.send %[[RESULT_2]] "key-2" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_3:.*]] = tf_device.receive "key-2" "/job:worker/replica:0/task:1" {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_4:.*]] = "tf.Identity"(%[[RESULT_3]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
%2 = "tf.Identity"(%1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return %[[RESULT_4]] : tensor<f32>
return %2 : tensor<f32>
}
// CHECK: func @test_send_func_arg(%[[ARG_0:.*]]: tensor<f32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) {
func @test_send_func_arg(%arg0: tensor<f32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) {
// CHECK-NEXT: tf_device.send %[[ARG_0]] "key-3" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_0:.*]] = tf_device.receive "key-3" "/job:worker/replica:0/task:0" {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[RESULT_0]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
%0 = "tf.Identity"(%arg0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<f32>) -> tensor<f32>
return
}
// CHECK: func @test_not_send_while_loop_arg(%[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<*xf32>, %[[ARG_2:.*]]: tensor<i32>) {
func @test_not_send_while_loop_arg(%arg0: tensor<i32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) {
// CHECK-NEXT: %[[RESULT_0:.*]]:2 = "tf.WhileRegion"(%[[ARG_0]], %[[ARG_1]]) ( {
%0:2 = "tf.WhileRegion"(%arg0, %arg1) ( {
// CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor<i32>, %[[ARG_4:.*]]: tensor<*xf32>)
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]]) {device = "/job:worker/replica:0/task:1/device:CPU:0"}
%2 = "tf.Identity"(%arg3) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: tf_device.send %[[RESULT_1]] "key-4" "/job:localhost/replica:0/task:0" {device = "/job:worker/replica:0/task:1/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_2:.*]] = tf_device.receive "key-4" "/job:worker/replica:0/task:1" {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
// CHECK-NEXT: %[[RESULT_3:.*]] = "tf.NotEqual"(%[[ARG_2]], %[[RESULT_2]])
%3 = "tf.NotEqual"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%3) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
%cst = constant dense<1> : tensor<i32>
%1 = "tf.Sub"(%arg3, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%1, %arg4) : (tensor<i32>, tensor<*xf32>) -> ()
}) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
return
}


@@ -0,0 +1,158 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass inserts tf_device.send and tf_device.receive ops to make sure any
// argument of any op is on the same host as the op itself.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace mlir {
namespace TF {
namespace {
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
constexpr const char *kOpDeviceAttr = "device";
constexpr const char *kArgDeviceAttr = "tf.device";
// TODO(b/175480458): Do not assign default host once every op in the TF
// dialect has the device attribute.
constexpr const char *kDefaultHost = "/job:localhost/replica:0/task:0";
constexpr const char *kCPUDevice = "/device:CPU:0";
// Return the job/replica/task from the device name as the host address. If no
// job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
// default host address.
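// For example, "/job:worker/replica:0/task:1/device:CPU:0" maps to
// "/job:worker/replica:0/task:1", and an empty device string maps to the
// default host.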
static std::string GetHost(const std::string &device) {
DeviceNameUtils::ParsedName parsed_name;
DeviceNameUtils::ParseFullName(device, &parsed_name);
parsed_name.has_id = false;
parsed_name.has_type = false;
auto host = DeviceNameUtils::ParsedNameToString(parsed_name);
if (host.empty()) return kDefaultHost;
return host;
}
// Return a globally unique string as the rendezvous key for cross-host value
// transfer.
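// Keys take the form "key-0", "key-1", and so on.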
static std::string GetNextKey() {
static int64_t next_index = 0;
std::string key;
llvm::raw_string_ostream os(key);
os << "key-" << next_index;
next_index++;
return key;
}
struct CrossHostTransferPass
: public PassWrapper<CrossHostTransferPass, FunctionPass> {
void runOnFunction() override;
};
void CrossHostTransferPass::runOnFunction() {
FuncOp func_op = getOperation();
// This map is used to avoid transferring the same value to the same host
// multiple times.
llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
transferred_value_by_value_and_host;
func_op.getBody().walk([&](Operation *op) {
if (op->isKnownTerminator()) return WalkResult::advance();
OpBuilder builder(op);
// Get the host address of the op.
std::string op_device = "";
if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
op_device = device_attr.getValue().str();
}
std::string dst_host = GetHost(op_device);
for (mlir::Value arg : op->getOperands()) {
// Get the host address of the argument.
std::string arg_device = "";
if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) {
// Do not transfer this block argument if it is not an argument of the
// enclosing function, e.g. if it is a while loop's region argument.
if (block_arg.getParentRegion() != &func_op.getRegion()) continue;
if (StringAttr device_attr = func_op.getArgAttrOfType<StringAttr>(
block_arg.getArgNumber(), kArgDeviceAttr)) {
arg_device = device_attr.getValue().str();
}
} else {
Operation *defining_op = arg.getDefiningOp();
if (StringAttr device_attr =
defining_op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
arg_device = device_attr.getValue().str();
}
}
std::string src_host = GetHost(arg_device);
if (src_host == dst_host) continue;
// Re-use the transferred argument if the argument has already been
// transferred to the given host.
llvm::StringMap<mlir::Value> &transferred_value_by_host =
transferred_value_by_value_and_host[arg];
auto iter = transferred_value_by_host.find(dst_host);
if (iter != transferred_value_by_host.end()) {
op->replaceUsesOfWith(arg, iter->second);
continue;
}
// Create tf_device.send and tf_device.receive ops to transfer the argument
// to the same host as the operation.
std::string key = GetNextKey();
auto send_op =
builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
send_op.setAttr(kOpDeviceAttr,
builder.getStringAttr(src_host + kCPUDevice));
auto receive_op = builder.create<tf_device::ReceiveOp>(
op->getLoc(), arg.getType(), key, src_host);
receive_op.setAttr(kOpDeviceAttr,
builder.getStringAttr(dst_host + kCPUDevice));
transferred_value_by_host[dst_host] = receive_op.getResult();
op->replaceUsesOfWith(arg, receive_op.getResult());
}
return WalkResult::advance();
});
}
} // namespace
std::unique_ptr<FunctionPass> CreateCrossHostTransferPass() {
return std::make_unique<CrossHostTransferPass>();
}
static PassRegistration<CrossHostTransferPass> pass(
"tf-cross-host-transfer",
"This pass inserts tf_device.send and tf_device.receive ops to make sure "
"any argument of any op is on the same host of the op itself.");
} // namespace TF
} // namespace mlir
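For reference, a hedged sketch (not part of this change) of how the pass could be wired into an MLIR pass pipeline; `context` and `module` are assumed to be an existing MLIRContext and ModuleOp:

  mlir::PassManager pm(&context);
  // Run the cross-host transfer pass on every function in the module.
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateCrossHostTransferPass());
  if (mlir::failed(pm.run(module))) {
    // Handle pass failure, e.g. report an error.
  }

The test file above exercises the same pass standalone via tf-opt --tf-cross-host-transfer.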


@@ -199,6 +199,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
// assignment of the result.
std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
// Creates a function pass that inserts tf_device.send and tf_device.receive
// ops to make sure any argument of any op is on the same host as the op itself.
std::unique_ptr<FunctionPass> CreateCrossHostTransferPass();
// Creates a pass that adds the device attribute to every tf.Const op based on
// the device attribute of the operations that read its result. If the result of
// a tf.Const op is read by operations placed on multiple devices, then the pass