Add a new `PjRtBuffer::HostBufferSemantics` enum that describes the possible contracts between caller and runtime. * Change `FromHostBuffer(..., force_copy, ...)` to `FromHostBuffer(..., host_buffer_semantics, ...)`. We were seeing some data races between modifications to a NumPy array and JAX on CPU, due to unintended buffer aliasing. This change allows clients to control whether they want zero-copy behavior or not. PiperOrigin-RevId: 316672280 Change-Id: Ibee296305005e0aa306a2c0aacf4b35a3d6c3ac1
149 lines
5.3 KiB
C++
149 lines
5.3 KiB
C++
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
|
|
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "absl/types/optional.h"
|
|
#include "pybind11/pybind11.h"
|
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
|
#include "tensorflow/compiler/xla/statusor.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
|
|
namespace xla {
|
|
|
|
class PyBuffer;
|
|
class PyClient;
|
|
class PyExecutable;
|
|
|
|
// Custom holder types.
|
|
//
|
|
// We must keep the PyClient object alive as long as any of the runtime
|
|
// objects are alive. Since we don't have a lot of control over Python
|
|
// destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
|
|
// and ensure that each Python runtime object holds a reference to the
|
|
// PyClient. An alternative design would be to keep a single global
|
|
// singleton PyClient, although this seems less flexible, especially for
|
|
// writing tests.
|
|
//
|
|
// To maintain PyClient references, we define pybind11 holder classes that
|
|
// are custom smart pointers that also keep a reference to a PyClient.
|
|
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
|
|
// seem sufficiently flexible to describe ownership relationships in cases where
|
|
// the ownership doesn't pertain to a direct argument or return value of a
|
|
// function. Another alternative to the holder classes would be to create proxy
|
|
// objects that contain both a reference and a runtime class; holder classes
|
|
// seem less tedious to define.
|
|
|
|
// A pair of a PyClient reference and an unowned pointer to T.
|
|
template <typename T>
|
|
struct ClientAndPtr {
|
|
ClientAndPtr() = default;
|
|
// pybind11 requires that we define a constructor that takes a raw pointer,
|
|
// but it should be unreachable.
|
|
explicit ClientAndPtr(T*) {
|
|
LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
|
|
}
|
|
|
|
ClientAndPtr(const ClientAndPtr&) = default;
|
|
ClientAndPtr(ClientAndPtr&&) = default;
|
|
ClientAndPtr& operator=(const ClientAndPtr&) = default;
|
|
ClientAndPtr& operator=(ClientAndPtr&&) = default;
|
|
|
|
std::shared_ptr<PyClient> client;
|
|
T* contents;
|
|
|
|
T* get() const { return contents; }
|
|
T* operator->() const { return contents; }
|
|
T& operator*() const { return *contents; }
|
|
};
|
|
|
|
// By defining a templated helper function, we can use return type deduction
|
|
// and avoid specifying types at the caller.
|
|
template <typename T>
|
|
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
|
|
ClientAndPtr<T> result;
|
|
result.client = std::move(client);
|
|
result.contents = contents;
|
|
return result;
|
|
}
|
|
|
|
// Python wrapper around PjRtClient.
|
|
// We use a wrapper class to add Python-specific functionality.
|
|
class PyClient : public std::enable_shared_from_this<PyClient> {
|
|
public:
|
|
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
|
|
|
|
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
|
|
|
|
const std::string& platform_name() const {
|
|
return pjrt_client_->platform_name();
|
|
}
|
|
int local_device_count() const { return pjrt_client_->local_device_count(); }
|
|
int device_count() const { return pjrt_client_->device_count(); }
|
|
int host_id() const { return pjrt_client_->host_id(); }
|
|
|
|
std::vector<ClientAndPtr<Device>> Devices();
|
|
std::vector<ClientAndPtr<Device>> LocalDevices();
|
|
|
|
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
|
|
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
|
|
|
|
// TODO(skye): delete after all callers can handle 2D output
|
|
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
|
|
int num_replicas);
|
|
|
|
StatusOr<ChannelHandle> CreateChannelHandle() {
|
|
return pjrt_client_->client()->CreateChannelHandle();
|
|
}
|
|
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
|
|
return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
|
|
}
|
|
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
|
|
return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
|
|
}
|
|
|
|
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
|
|
const pybind11::object& argument, Device* device, bool force_copy,
|
|
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
|
|
|
|
StatusOr<std::unique_ptr<PyExecutable>> Compile(
|
|
const XlaComputation& computation, CompileOptions options);
|
|
|
|
pybind11::bytes HeapProfile();
|
|
|
|
private:
|
|
friend class PyBuffer;
|
|
friend class PyExecutable;
|
|
|
|
std::shared_ptr<PjRtClient> pjrt_client_;
|
|
|
|
// Pointers to intrusive doubly-linked lists of buffers and executables, used
|
|
// to iterate over all known objects when heap profiling. The list structure
|
|
// is protected by the GIL.
|
|
PyBuffer* buffers_ = nullptr;
|
|
PyExecutable* executables_ = nullptr;
|
|
};
|
|
|
|
} // namespace xla
|
|
|
|
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
|
|
|
|
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
|