[XLA:Python] Fix destructor ordering problem.
Wait for devices to quiesce before deleting the PythonRefManager. Will fix the remaining problem in https://github.com/google/jax/issues/927 when deployed in jaxlib. PiperOrigin-RevId: 255309774
This commit is contained in:
parent
e4ab9c20b2
commit
6ae9600988
@ -214,13 +214,18 @@ class PyLocalClient {
|
||||
protected:
|
||||
std::string platform_name_;
|
||||
LocalClient* client_;
|
||||
|
||||
// py_ref_manager_ must come after devices_ in the class destruction order
|
||||
// (i.e., appear first in the class.)
|
||||
// Destruction of devices waits for them to quiesce; callbacks on device
|
||||
// streams may refer to py_ref_manager_ and we must wait for them to complete.
|
||||
PythonRefManager py_ref_manager_;
|
||||
|
||||
std::vector<std::unique_ptr<Device>> devices_;
|
||||
se::DeviceMemoryAllocator* allocator_;
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
|
||||
|
||||
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
||||
|
||||
PythonRefManager py_ref_manager_;
|
||||
};
|
||||
|
||||
// Holds a reference from Python to one or more device buffers.
|
||||
|
Loading…
Reference in New Issue
Block a user