[XLA:Python] Fix destructor ordering problem.

Wait for devices to quiesce before deleting the PythonRefManager.

Will fix the remaining problem in https://github.com/google/jax/issues/927 when deployed in jaxlib.

PiperOrigin-RevId: 255309774
This commit is contained in:
Peter Hawkins 2019-06-26 18:51:56 -07:00 committed by TensorFlower Gardener
parent e4ab9c20b2
commit 6ae9600988

View File

@ -214,13 +214,18 @@ class PyLocalClient {
protected:
std::string platform_name_;
LocalClient* client_;
// py_ref_manager_ must come after devices_ in the class destruction order
// (i.e., appear first in the class.)
// Destruction of devices waits for them to quiesce; callbacks on device
// streams may refer to py_ref_manager_ and we must wait for them to complete.
PythonRefManager py_ref_manager_;
std::vector<std::unique_ptr<Device>> devices_;
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
tensorflow::thread::ThreadPool h2d_transfer_pool_;
PythonRefManager py_ref_manager_;
};
// Holds a reference from Python to one or more device buffers.