Use a handle instead of an object.

PiperOrigin-RevId: 350188697
Change-Id: Iaef8cab63d2a3ad626021d409d658fe70f3b29dc
This commit is contained in:
Jean-Baptiste Lespiau 2021-01-05 11:54:10 -08:00 committed by TensorFlower Gardener
parent 0f93dd9528
commit ec2403ba4f
2 changed files with 4 additions and 4 deletions

View File

@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
} }
StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval( StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) { PjRtClient::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) { if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty()); TF_RET_CHECK(!pjrt_client_->local_devices().empty());
@ -123,7 +123,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
return buffer; return buffer;
} }
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval( StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) { PjRtClient::HostBufferSemantics host_buffer_semantics) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer, std::unique_ptr<PjRtBuffer> buffer,

View File

@ -124,10 +124,10 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
} }
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval( StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics); PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval( StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics); PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::shared_ptr<PyExecutable>> Compile( StatusOr<std::shared_ptr<PyExecutable>> Compile(