[XLA:Python] Fix destructor ordering problem.
Wait for devices to quiesce before deleting the PythonRefManager. Will fix the remaining problem in https://github.com/google/jax/issues/927 when deployed in jaxlib. PiperOrigin-RevId: 255309774
This commit is contained in:
parent
e4ab9c20b2
commit
6ae9600988
@ -214,13 +214,18 @@ class PyLocalClient {
|
|||||||
protected:
|
protected:
|
||||||
std::string platform_name_;
|
std::string platform_name_;
|
||||||
LocalClient* client_;
|
LocalClient* client_;
|
||||||
|
|
||||||
|
// py_ref_manager_ must come after devices_ in the class destruction order
|
||||||
|
// (i.e., appear first in the class.)
|
||||||
|
// Destruction of devices waits for them to quiesce; callbacks on device
|
||||||
|
// streams may refer to py_ref_manager_ and we must wait for them to complete.
|
||||||
|
PythonRefManager py_ref_manager_;
|
||||||
|
|
||||||
std::vector<std::unique_ptr<Device>> devices_;
|
std::vector<std::unique_ptr<Device>> devices_;
|
||||||
se::DeviceMemoryAllocator* allocator_;
|
se::DeviceMemoryAllocator* allocator_;
|
||||||
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
|
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
|
||||||
|
|
||||||
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
||||||
|
|
||||||
PythonRefManager py_ref_manager_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Holds a reference from Python to one or more device buffers.
|
// Holds a reference from Python to one or more device buffers.
|
||||||
|
Loading…
Reference in New Issue
Block a user