[TFRT:Compiler] The tf-cross-host-transfer pass should assign rendezvous key in a deterministic way

PiperOrigin-RevId: 348402376
Change-Id: If7d1a6b99045751658ab2488e5d1e5bcc9cc9cd2
This commit is contained in:
Dong Lin 2020-12-20 19:10:41 -08:00 committed by TensorFlower Gardener
parent bd53eb5e75
commit e6380aa5a5
2 changed files with 24 additions and 22 deletions
tensorflow/compiler/mlir/tensorflow/transforms

View File

@ -43,7 +43,7 @@ constexpr const char *kCPUDevice = "/device:CPU:0";
// Return the job/replica/task from the device name as the host address. If no
// job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
// default host address.
static std::string GetHost(const std::string &device) {
std::string GetHost(const std::string &device) {
DeviceNameUtils::ParsedName parsed_name;
DeviceNameUtils::ParseFullName(device, &parsed_name);
parsed_name.has_id = false;
@ -55,25 +55,25 @@ static std::string GetHost(const std::string &device) {
return host;
}
// Return a globally unique string as the rendezvous key for cross-host value
// transfer.
static std::string GetNextKey() {
static int64_t next_index = 0;
std::string key;
llvm::raw_string_ostream os(key);
os << "key-" << next_index;
next_index++;
return key;
}
struct CrossHostTransferPass
: public PassWrapper<CrossHostTransferPass, FunctionPass> {
void runOnFunction() override;
: public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
private:
// The key_count represents the total number of send/recv pairs generated
// before this method call. And the key_count should be incremented based
// on the send/recv pairs newly generated by this method call.
void runOnFunction(FuncOp func_op, int &key_count);
};
void CrossHostTransferPass::runOnFunction() {
FuncOp func_op = getOperation();
void CrossHostTransferPass::runOnOperation() {
ModuleOp module = getOperation();
int key_count = 0;
module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); });
}
void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) {
// This map is used to avoid transferring the same value to the same host
// multiple times.
llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
@ -125,7 +125,9 @@ void CrossHostTransferPass::runOnFunction() {
// Create tf_device.send and tf_device.receive ops to send the argument to
// the same host of the operation.
std::string key = GetNextKey();
std::string key = "key-" + std::to_string(key_count);
key_count++;
auto send_op =
builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
send_op->setAttr(kOpDeviceAttr,
@ -145,7 +147,7 @@ void CrossHostTransferPass::runOnFunction() {
} // namespace
std::unique_ptr<FunctionPass> CreateCrossHostTransferPass() {
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() {
return std::make_unique<CrossHostTransferPass>();
}

View File

@ -199,9 +199,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
// assignment of the result.
std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
// Creates function pass to insert tf_device.send and tf_device.receive ops to
// make sure any argument of any op is on the same host of the op itself.
std::unique_ptr<FunctionPass> CreateCrossHostTransferPass();
// Creates a pass to insert tf_device.send and tf_device.receive ops to make
// sure any argument of any op is on the same host of the op itself.
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
// Creates a pass that adds the device attribute to every tf.Const op based on
// the device attribute of the operations that read its result. If the result of