diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index c94a206a926..be61bd74419 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -91,6 +91,7 @@ class PyClient : public std::enable_shared_from_this { explicit PyClient(std::shared_ptr pjrt_client); PjRtClient* pjrt_client() const { return pjrt_client_.get(); } + std::shared_ptr shared_pjrt_client() { return pjrt_client_; } const std::string& platform_name() const { return pjrt_client_->platform_name();