From e6380aa5a54aca5bd20ddf7270e92c7cf7965739 Mon Sep 17 00:00:00 2001
From: Dong Lin <donglin@google.com>
Date: Sun, 20 Dec 2020 19:10:41 -0800
Subject: [PATCH] [TFRT:Compiler] The tf-cross-host-transfer pass should assign
 rendezvous key in a deterministic way

PiperOrigin-RevId: 348402376
Change-Id: If7d1a6b99045751658ab2488e5d1e5bcc9cc9cd2
---
 .../transforms/cross_host_transfer.cc         | 40 ++++++++++---------
 .../mlir/tensorflow/transforms/passes.h       |  6 +--
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc
index 06dea80dbe8..b1c63b81695 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cross_host_transfer.cc
@@ -43,7 +43,7 @@ constexpr const char *kCPUDevice = "/device:CPU:0";
 // Return the job/replica/task from the device name as the host address. If no
 // job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
 // default host address.
-static std::string GetHost(const std::string &device) {
+std::string GetHost(const std::string &device) {
   DeviceNameUtils::ParsedName parsed_name;
   DeviceNameUtils::ParseFullName(device, &parsed_name);
   parsed_name.has_id = false;
@@ -55,25 +55,25 @@ static std::string GetHost(const std::string &device) {
   return host;
 }
 
-// Return a globally unique string as the rendezvous key for cross-host value
-// transfer.
-static std::string GetNextKey() {
-  static int64_t next_index = 0;
-  std::string key;
-  llvm::raw_string_ostream os(key);
-  os << "key-" << next_index;
-  next_index++;
-
-  return key;
-}
-
 struct CrossHostTransferPass
-    : public PassWrapper<CrossHostTransferPass, FunctionPass> {
-  void runOnFunction() override;
+    : public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+
+ private:
+  // The key_count represents the total number of send/recv pairs generated
+  // before this method call. And the key_count should be incremented based
+  // on the send/recv pairs newly generated by this method call.
+  void runOnFunction(FuncOp func_op, int &key_count);
 };
 
-void CrossHostTransferPass::runOnFunction() {
-  FuncOp func_op = getOperation();
+void CrossHostTransferPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  int key_count = 0;
+
+  module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); });
+}
+
+void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) {
   // This map is used to avoid transferring the same value to the same host
   // multiple times.
   llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
@@ -125,7 +125,9 @@ void CrossHostTransferPass::runOnFunction() {
 
       // Create tf_device.send and tf_device.receive ops to send the argument to
       // the same host of the operation.
-      std::string key = GetNextKey();
+      std::string key = "key-" + std::to_string(key_count);
+      key_count++;
+
       auto send_op =
           builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
       send_op->setAttr(kOpDeviceAttr,
@@ -145,7 +147,7 @@ void CrossHostTransferPass::runOnFunction() {
 
 }  // namespace
 
-std::unique_ptr<FunctionPass> CreateCrossHostTransferPass() {
+std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() {
   return std::make_unique<CrossHostTransferPass>();
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 681e0a9e174..ca89c7dabaa 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -199,9 +199,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
 // assignment of the result.
 std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
 
-// Creates function pass to insert tf_device.send and tf_device.receive ops to
-// make sure any argument of any op is on the same host of the op itself.
-std::unique_ptr<FunctionPass> CreateCrossHostTransferPass();
+// Creates a pass to insert tf_device.send and tf_device.receive ops to make
+// sure any argument of any op is on the same host of the op itself.
+std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
 
 // Creates a pass that adds the device attribute to every tf.Const op based on
 // the device attribute of the operations that read its result. If the result of