From e6380aa5a54aca5bd20ddf7270e92c7cf7965739 Mon Sep 17 00:00:00 2001 From: Dong Lin <donglin@google.com> Date: Sun, 20 Dec 2020 19:10:41 -0800 Subject: [PATCH] [TFRT:Compiler] The tf-cross-host-transfer pass should assign rendezvous key in a deterministic way PiperOrigin-RevId: 348402376 Change-Id: If7d1a6b99045751658ab2488e5d1e5bcc9cc9cd2 --- .../transforms/cross_host_transfer.cc | 40 ++++++++++--------- .../mlir/tensorflow/transforms/passes.h | 6 +-- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc index 06dea80dbe8..b1c63b81695 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc @@ -43,7 +43,7 @@ constexpr const char *kCPUDevice = "/device:CPU:0"; // Return the job/replica/task from the device name as the host address. If no // job/replica/task is specified, return /job:localhost/replica:0/task:0 as the // default host address. -static std::string GetHost(const std::string &device) { +std::string GetHost(const std::string &device) { DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParseFullName(device, &parsed_name); parsed_name.has_id = false; @@ -55,25 +55,25 @@ static std::string GetHost(const std::string &device) { return host; } -// Return a globally unique string as the rendezvous key for cross-host value -// transfer. 
-static std::string GetNextKey() { - static int64_t next_index = 0; - std::string key; - llvm::raw_string_ostream os(key); - os << "key-" << next_index; - next_index++; - - return key; -} - struct CrossHostTransferPass - : public PassWrapper<CrossHostTransferPass, FunctionPass> { - void runOnFunction() override; + : public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> { + void runOnOperation() override; + + private: + // On entry, key_count is the total number of send/recv pairs generated + // before this call. It is incremented once for each send/recv pair that + // this call generates. + void runOnFunction(FuncOp func_op, int &key_count); }; -void CrossHostTransferPass::runOnFunction() { - FuncOp func_op = getOperation(); +void CrossHostTransferPass::runOnOperation() { + ModuleOp module = getOperation(); + int key_count = 0; + + module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); }); +} + +void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) { // This map is used to avoid transferring the same value to the same host // multiple times. llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>> @@ -125,7 +125,9 @@ void CrossHostTransferPass::runOnFunction() { // Create tf_device.send and tf_device.receive ops to send the argument to // the same host of the operation. 
- std::string key = GetNextKey(); + std::string key = "key-" + std::to_string(key_count); + key_count++; + auto send_op = builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host); send_op->setAttr(kOpDeviceAttr, @@ -145,7 +147,7 @@ void CrossHostTransferPass::runOnFunction() { } // namespace -std::unique_ptr<FunctionPass> CreateCrossHostTransferPass() { +std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() { return std::make_unique<CrossHostTransferPass>(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 681e0a9e174..ca89c7dabaa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -199,9 +199,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass(); // assignment of the result. std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass(); -// Creates function pass to insert tf_device.send and tf_device.receive ops to -// make sure any argument of any op is on the same host of the op itself. -std::unique_ptr<FunctionPass> CreateCrossHostTransferPass(); +// Creates a pass to insert tf_device.send and tf_device.receive ops to make +// sure any argument of any op is on the same host as the op itself. +std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass(); // Creates a pass that adds the device attribute to every tf.Const op based on // the device attribute of the operations that read its result. If the result of