STT-tensorflow/tensorflow/compiler/xla/python/py_client.h
Peter Hawkins 572442eb16 [PJRT] Fix potential misuse of PjRtBuffer::FromHostBuffer.
Add a new `PjRtBuffer::HostBufferSemantics` enum that describes the possible contracts between caller and runtime.

* Change `FromHostBuffer(..., force_copy, ...)` to `FromHostBuffer(..., host_buffer_semantics, ...)`.

We were seeing data races on CPU between in-place modifications of a NumPy array and JAX's concurrent use of the same memory, caused by unintended buffer aliasing. This change allows clients to control whether or not they want zero-copy behavior.
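
As a rough illustration (not part of this change's text; the authoritative definitions live in pjrt_client.h, and the exact enumerator names shown here are an assumption), the new enum spells out three caller/runtime contracts along these lines:

enum class HostBufferSemantics {
  // The runtime may only read the host memory during the FromHostBuffer call
  // itself; the caller may mutate or free it as soon as the call returns.
  kImmutableOnlyDuringCall,
  // The caller promises not to mutate or free the host memory until the
  // transfer to the device has completed.
  kImmutableUntilTransferCompletes,
  // The runtime may alias the host memory for as long as the resulting buffer
  // is alive (zero-copy); the caller must keep it valid and unchanged.
  kZeroCopy,
};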

PiperOrigin-RevId: 316672280
Change-Id: Ibee296305005e0aa306a2c0aacf4b35a3d6c3ac1
2020-06-16 06:59:42 -07:00

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
class PyBuffer;
class PyClient;
class PyExecutable;
// Custom holder types.
//
// We must keep the PyClient object alive as long as any of the runtime
// objects are alive. Since we don't have a lot of control over Python
// destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
// and ensure that each Python runtime object holds a reference to the
// PyClient. An alternative design would be to keep a single global
// singleton PyClient, although this seems less flexible, especially for
// writing tests.
//
// To maintain PyClient references, we define pybind11 holder classes that
// are custom smart pointers that also keep a reference to a PyClient.
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
// seem sufficiently flexible to describe ownership relationships in cases where
// the ownership doesn't pertain to a direct argument or return value of a
// function. Another alternative to the holder classes would be to create proxy
// objects that contain both a reference and a runtime class; holder classes
// seem less tedious to define.
// A pair of a PyClient reference and an unowned pointer to T.
template <typename T>
struct ClientAndPtr {
ClientAndPtr() = default;
// pybind11 requires that we define a constructor that takes a raw pointer,
// but it should be unreachable.
explicit ClientAndPtr(T*) {
LOG(FATAL) << "ClientAndPtr should be constructed via WrapWithClient.";
}
ClientAndPtr(const ClientAndPtr&) = default;
ClientAndPtr(ClientAndPtr&&) = default;
ClientAndPtr& operator=(const ClientAndPtr&) = default;
ClientAndPtr& operator=(ClientAndPtr&&) = default;
std::shared_ptr<PyClient> client;
T* contents = nullptr;
T* get() const { return contents; }
T* operator->() const { return contents; }
T& operator*() const { return *contents; }
};
// By defining a templated helper function, we let template argument deduction
// supply T, so callers don't have to spell out the ClientAndPtr<T> type.
template <typename T>
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
ClientAndPtr<T> result;
result.client = std::move(client);
result.contents = contents;
return result;
}
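
// Illustrative sketch (the real definitions live in py_client.cc; the exact
// PjRtClient accessor used below is an assumption): a method that hands
// runtime objects back to Python wraps each raw pointer together with a
// reference to the owning PyClient, roughly:
//
//   std::vector<ClientAndPtr<Device>> PyClient::Devices() {
//     std::vector<ClientAndPtr<Device>> devices;
//     for (const auto& device : pjrt_client_->devices()) {
//       devices.push_back(WrapWithClient(shared_from_this(), device.get()));
//     }
//     return devices;
//   }
//
// Every returned wrapper carries a shared_ptr<PyClient>, so the client cannot
// be destroyed while Python still holds any of the wrapped objects.
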
// Python wrapper around PjRtClient.
// We use a wrapper class to add Python-specific functionality.
class PyClient : public std::enable_shared_from_this<PyClient> {
public:
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
const std::string& platform_name() const {
return pjrt_client_->platform_name();
}
int local_device_count() const { return pjrt_client_->local_device_count(); }
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }
std::vector<ClientAndPtr<Device>> Devices();
std::vector<ClientAndPtr<Device>> LocalDevices();
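// Returns the default device assignment for a computation with `num_replicas`
// replicas and `num_partitions` partitions, as a [replica][partition] grid of
// devices (the 2D shape the TODO below refers to).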
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
// TODO(skye): delete after all callers can handle 2D output
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
int num_replicas);
StatusOr<ChannelHandle> CreateChannelHandle() {
return pjrt_client_->client()->CreateChannelHandle();
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
}
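// Creates a device buffer on `device` from a Python value (e.g. a NumPy
// array). `force_copy` and `host_buffer_semantics` together control whether
// the runtime may alias the host memory instead of copying it; see
// PjRtBuffer::HostBufferSemantics for the exact contracts.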
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, Device* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
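// Compiles `computation` with `options` and wraps the result as a
// PyExecutable that keeps this client alive.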
StatusOr<std::unique_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
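// Returns a heap profile covering the live buffers and executables tracked by
// this client (see the intrusive lists below).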
pybind11::bytes HeapProfile();
private:
friend class PyBuffer;
friend class PyExecutable;
std::shared_ptr<PjRtClient> pjrt_client_;
// Pointers to intrusive doubly-linked lists of buffers and executables, used
// to iterate over all known objects when heap profiling. The list structure
// is protected by the GIL.
PyBuffer* buffers_ = nullptr;
PyExecutable* executables_ = nullptr;
};
} // namespace xla
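
// Registers ClientAndPtr<T> as a pybind11 holder type (analogous to
// std::shared_ptr<T>), so bound functions can return runtime objects wrapped
// this way. It is placed outside namespace xla because the macro defines
// specializations in pybind11's own namespace.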
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_