[TFRT:Compiler] The tf-cross-host-transfer pass should assign rendezvous keys deterministically
PiperOrigin-RevId: 348402376 Change-Id: If7d1a6b99045751658ab2488e5d1e5bcc9cc9cd2
This commit is contained in:
parent
bd53eb5e75
commit
e6380aa5a5
tensorflow/compiler/mlir/tensorflow/transforms
@ -43,7 +43,7 @@ constexpr const char *kCPUDevice = "/device:CPU:0";
|
||||
// Return the job/replica/task from the device name as the host address. If no
|
||||
// job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
|
||||
// default host address.
|
||||
static std::string GetHost(const std::string &device) {
|
||||
std::string GetHost(const std::string &device) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
DeviceNameUtils::ParseFullName(device, &parsed_name);
|
||||
parsed_name.has_id = false;
|
||||
@ -55,25 +55,25 @@ static std::string GetHost(const std::string &device) {
|
||||
return host;
|
||||
}
|
||||
|
||||
// Return a globally unique string as the rendezvous key for cross-host value
|
||||
// transfer.
|
||||
static std::string GetNextKey() {
  // Function-local static: the counter persists across calls, so the key
  // returned depends on how many times this function has already been called.
  // The increment is not synchronized.
  static int64_t next_index = 0;

  // Build the key directly instead of streaming into an
  // llvm::raw_string_ostream and returning its backing string while the stream
  // is still alive and unflushed; that pattern only yields the full text when
  // the stream is unbuffered or the return is elided.
  return "key-" + std::to_string(next_index++);
}
|
||||
|
||||
struct CrossHostTransferPass
|
||||
: public PassWrapper<CrossHostTransferPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
: public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
// `key_count` is the total number of send/recv pairs generated before this
|
||||
// call. It is incremented once for every send/recv pair this call generates,
|
||||
// so keys derived from it do not repeat across the functions of a module.
|
||||
void runOnFunction(FuncOp func_op, int &key_count);
|
||||
};
|
||||
|
||||
void CrossHostTransferPass::runOnFunction() {
|
||||
FuncOp func_op = getOperation();
|
||||
void CrossHostTransferPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
int key_count = 0;
|
||||
|
||||
module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); });
|
||||
}
|
||||
|
||||
void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) {
|
||||
// This map is used to avoid transferring the same value to the same host
|
||||
// multiple times.
|
||||
llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
|
||||
@ -125,7 +125,9 @@ void CrossHostTransferPass::runOnFunction() {
|
||||
|
||||
// Create tf_device.send and tf_device.receive ops to send the argument to
|
||||
// the same host as the operation.
|
||||
std::string key = GetNextKey();
|
||||
std::string key = "key-" + std::to_string(key_count);
|
||||
key_count++;
|
||||
|
||||
auto send_op =
|
||||
builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
|
||||
send_op->setAttr(kOpDeviceAttr,
|
||||
@ -145,7 +147,7 @@ void CrossHostTransferPass::runOnFunction() {
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<FunctionPass> CreateCrossHostTransferPass() {
|
||||
// Factory for the cross-host transfer pass: constructs a new
// CrossHostTransferPass and hands ownership to the caller as a pass that
// operates on a whole mlir::ModuleOp.
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() {
|
||||
return std::make_unique<CrossHostTransferPass>();
|
||||
}
|
||||
|
||||
|
@ -199,9 +199,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
|
||||
// assignment of the result.
|
||||
std::unique_ptr<FunctionPass> CreateClusterTFOpsByHostPass();
|
||||
|
||||
// Creates function pass to insert tf_device.send and tf_device.receive ops to
|
||||
// make sure any argument of any op is on the same host of the op itself.
|
||||
std::unique_ptr<FunctionPass> CreateCrossHostTransferPass();
|
||||
// Creates a pass to insert tf_device.send and tf_device.receive ops to make
|
||||
// sure any argument of any op is on the same host as the op itself.
|
||||
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
|
||||
|
||||
// Creates a pass that adds the device attribute to every tf.Const op based on
|
||||
// the device attribute of the operations that read its result. If the result of
|
||||
|
Loading…
Reference in New Issue
Block a user