STT-tensorflow/tensorflow/compiler/xla/python/jax_jit.cc
Jean-Baptiste Lespiau 5fff2a4bca Add support for more numpy types, and use a map-lookup design.
PiperOrigin-RevId: 351424260
Change-Id: I7d6e81574e0e3583ab4dbf59fdde75cfb5c951d2
2021-01-12 13:23:55 -08:00

1076 lines
43 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This files implements the `jax.jit` dispatch and just-in-time feature.
//
// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
// based on passed arguments dtypes/shapes/identity) the execution to a
// just-in-time compiled XLA Executable. All of that is done in C++ for
// performance reasons.
//
// This file contains the utilities to:
// (a) inspect arguments and describe their structure, dtype/shapes, etc.
// (b) keep a mapping from function signatures to compiled XLA Executables.
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include <Python.h>
#include <exception>
#include <memory>
#include <stdexcept>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/status.h"
namespace xla {
namespace py = pybind11;
// TODO(phawkins): Add support for Tracers.
// TODO(jblespiau): Use absl Status.
// TODO(jblespiau): Remove the "xla::" prefixes when not needed.
std::string ArgSignature::DebugString() const {
std::string result = "";
if (weak_type) {
absl::StrAppend(&result, "weak_");
}
absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
return result;
}
bool CallSignature::operator==(const CallSignature& other) const {
return std::tie(dynamic_positional_args_treedef, keyword_args,
dynamic_args_signatures, device) ==
std::tie(other.dynamic_positional_args_treedef, other.keyword_args,
other.dynamic_args_signatures, other.device) &&
// `==` on py:objects is the Python `is`. We need equal.
std::equal(
static_args.begin(), static_args.end(), other.static_args.begin(),
other.static_args.end(),
[](const py::object& a, const py::object& b) {
try {
return a.equal(b);
} catch (const py::error_already_set& e) {
throw std::invalid_argument(absl::StrCat(
"static arguments should be comparable using __eq__."
"The following error was raised when comparing two "
"objects of types ",
py::cast<std::string>(py::str(py::type::of(a))), " and ",
py::cast<std::string>(py::str(py::type::of(b))),
". The error was:\n", e.what()));
}
});
}
void CallSignature::IncRef() const {
for (const auto& kw : keyword_args) {
kw.key.inc_ref();
}
}
void CallSignature::DecRef() const {
for (const auto& kw : keyword_args) {
kw.key.dec_ref();
}
}
namespace {
thread_local bool disable_jit;
void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; }
bool GetDisableJit() { return disable_jit; }
} // namespace
std::string CallSignature::DebugString() const {
std::vector<std::string> static_args_str;
static_args_str.reserve(static_args.size());
for (auto& static_arg : static_args) {
static_args_str.emplace_back(py::cast<std::string>(py::str(static_arg)));
}
std::vector<std::string> signature_str;
signature_str.reserve(dynamic_args_signatures.size());
for (auto& arg_signature : dynamic_args_signatures) {
signature_str.emplace_back(arg_signature.DebugString());
}
std::vector<std::string> tree_def_str;
signature_str.reserve(dynamic_positional_args_treedef.size());
for (auto& tree_def : dynamic_positional_args_treedef) {
tree_def_str.emplace_back(tree_def.ToString());
}
std::vector<std::string> keyword_names;
keyword_names.reserve(keyword_args.size());
for (auto& kwarg_entry : keyword_args) {
keyword_names.emplace_back(py::cast<std::string>(kwarg_entry.key));
tree_def_str.emplace_back(kwarg_entry.value_treedef.ToString());
}
return absl::StrCat(
static_args.size(), " static_args: ", absl::StrJoin(static_args_str, ","),
"\n", // new line
keyword_args.size(), " keyword args:", absl::StrJoin(keyword_names, ","),
"\n", // new-line
dynamic_positional_args_treedef.size(), " positional args.\n",
dynamic_args_signatures.size(),
" dynamic args (positional+keyword):\n - ",
absl::StrJoin(signature_str, ", "), "\n - ",
absl::StrJoin(tree_def_str, " | "));
}
template <typename H>
H AbslHashValue(H h, const CallSignature& s) {
h = H::combine_contiguous(std::move(h),
s.dynamic_positional_args_treedef.data(),
s.dynamic_positional_args_treedef.size());
h = H::combine_contiguous(std::move(h), s.keyword_args.data(),
s.keyword_args.size());
h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(),
s.dynamic_args_signatures.size());
h = H::combine(std::move(h), s.device);
for (const auto& static_arg : s.static_args) {
ssize_t hash;
try {
hash = py::hash(static_arg);
} catch (const py::error_already_set& e) {
throw std::invalid_argument(absl::StrCat(
"Non-hashable static arguments are not supported. An error occured "
"while trying to hash an object of type ",
py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
e.what(), "\n"));
}
h = H::combine(std::move(h), hash);
}
return h;
}
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments) {
if (static_argnums.size() > args.size()) {
return InvalidArgument(
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
}
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
static_argnums.size());
arguments.signature.dynamic_positional_args_treedef.reserve(
args.size() - static_argnums.size());
// Positional arguments.
for (size_t i = 0; i < args.size(); ++i) {
if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
static_argnums.end()) {
PyTreeDef pytree_def;
pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
} else {
arguments.signature.static_args.emplace_back(
// borrow is mandatory here.
py::reinterpret_borrow<py::object>(args[i]));
}
}
// Keyword arguments.
std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
py_kwargs.end());
// We first intern the keys, then sort them (by name, as in the Python path)
// (see also PyTreeDef::Flatten) and then create the signatures.
// TODO(jblespiau): We should be able to sort the keys by interned-key
// pointers, but this requires the Python compilation to do the same.
arguments.signature.keyword_args.resize(kwargs.size());
for (size_t i = 0; i < kwargs.size(); ++i) {
// Intern the key if not already interned.
if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
PyObject* key = kwargs[i].first.ptr();
kwargs[i].first.inc_ref();
PyUnicode_InternInPlace(&key);
arguments.keep_alive_objects.push_back(
py::reinterpret_steal<py::object>(key));
kwargs[i].first = py::handle(key);
}
}
std::sort(kwargs.begin(), kwargs.end(),
[](const std::pair<py::handle, py::handle>& a,
const std::pair<py::handle, py::handle>& b) {
return a.first < b.first;
});
for (size_t i = 0; i < kwargs.size(); ++i) {
arguments.signature.keyword_args[i].key = kwargs[i].first;
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
kwargs[i].second, arguments.flat_dynamic_args);
}
return Status::OK();
}
namespace {
struct NumpyScalarTypes {
py::object np_bool;
py::object np_int8;
py::object np_int16;
py::object np_int32;
py::object np_int64;
py::object np_uint8;
py::object np_uint16;
py::object np_uint32;
py::object np_uint64;
py::object np_float16;
py::object np_float32;
py::object np_float64;
py::object np_complex64;
py::object np_complex128;
py::object np_longlong;
py::object np_intc;
};
const NumpyScalarTypes& GetNumpyScalarTypes() {
static const NumpyScalarTypes* singleton = []() {
// Use Designated initializers when they are available.
const auto numpy = py::module::import("numpy");
NumpyScalarTypes* dtypes = new NumpyScalarTypes();
dtypes->np_bool = py::object(numpy.attr("bool_"));
dtypes->np_int8 = py::object(numpy.attr("int8"));
dtypes->np_int16 = py::object(numpy.attr("int16"));
dtypes->np_int32 = py::object(numpy.attr("int32"));
dtypes->np_int64 = py::object(numpy.attr("int64"));
dtypes->np_uint8 = py::object(numpy.attr("uint8"));
dtypes->np_uint16 = py::object(numpy.attr("uint16"));
dtypes->np_uint32 = py::object(numpy.attr("uint32"));
dtypes->np_uint64 = py::object(numpy.attr("uint64"));
dtypes->np_float16 = py::object(numpy.attr("float16"));
dtypes->np_float32 = py::object(numpy.attr("float32"));
dtypes->np_float64 = py::object(numpy.attr("float64"));
dtypes->np_complex64 = py::object(numpy.attr("complex64"));
dtypes->np_complex128 = py::object(numpy.attr("complex128"));
dtypes->np_longlong = py::object(numpy.attr("longlong"));
dtypes->np_intc = py::object(numpy.attr("intc"));
return dtypes;
}();
return *singleton;
}
const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
// TODO(jblespiau): Use GetNumpyScalarTypes instead.
static const auto* int64_dt = new py::dtype("int64");
static const auto* int32_dt = new py::dtype("int32");
static const auto* uint64_dt = new py::dtype("uint64");
static const auto* uint32_dt = new py::dtype("uint32");
static const auto* float64_dt = new py::dtype("float64");
static const auto* float32_dt = new py::dtype("float32");
static const auto* complex64_dt = new py::dtype("complex64");
static const auto* complex128_dt = new py::dtype("complex128");
if (dtype.equal(*int64_dt)) {
return int32_dt;
}
if (dtype.equal(*float64_dt)) {
return float32_dt;
}
if (dtype.equal(*uint64_dt)) {
return uint32_dt;
}
if (dtype.equal(*complex128_dt)) {
return complex64_dt;
}
return nullptr;
}
// The equivalent of the Python jax/lazy.py::is_trivial:
// return (type(lexpr.input) is ArrayVar and
// lexpr.dims == tuple(range(len(lexpr.shape))))
//
// Expects *only* `None` or a LazyExpr` object.
bool IsTrivialLazyExpr(py::handle lexpr) {
if (lexpr.is_none()) {
return true;
}
static const auto* lazy_module =
new py::module(py::module::import("jax.lazy"));
auto input = py::getattr(lexpr, "input");
if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
return false;
}
py::tuple dims = py::cast<py::tuple>(lexpr.attr("dims"));
py::tuple shape = py::cast<py::tuple>(lexpr.attr("shape"));
for (int i = 0; i < shape.size(); ++i) {
if (dims[i].is_none()) {
return false;
}
if (py::cast<int>(dims[i]) != i) {
return false;
}
}
return true;
}
bool IsFloat0(py::array arg) {
static const auto* dtypes_module =
new py::module(py::module::import("jax.dtypes"));
static const auto* float0_dtype =
new py::handle(dtypes_module->attr("float0"));
return float0_dtype->is(arg.attr("dtype"));
}
template <typename CppType, typename Pybind11Type>
std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
const py::handle& scalar, xla::PjRtClient* client,
xla::PjRtDevice* device) {
CppType data = py::cast<Pybind11Type>(scalar);
// Work around for https://github.com/pybind/pybind11/issues/2786
if (PyErr_Occurred()) {
throw py::error_already_set();
}
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape,
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
device));
}
} // namespace
namespace {
using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
py::handle, PjRtDevice*, bool, xla::PyClient&)>;
DevicePutResult HandleBool(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
h, pyclient.pjrt_client(), to_device),
/*weak_type=*/true);
}
DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
if (jax_enable_x64) {
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
obj, pyclient.pjrt_client(), to_device),
/*weak_type=*/true);
} else {
return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
obj, pyclient.pjrt_client(), to_device),
/*weak_type=*/true);
}
}
template <bool weak_type>
StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
if (jax_enable_x64) {
return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
h, pyclient.pjrt_client(), to_device),
/*weak_type=*/weak_type);
} else {
return DevicePutResult(ConvertToScalarBuffer<float, py::float_>(
h, pyclient.pjrt_client(), to_device),
/*weak_type=*/weak_type);
}
}
template <bool weak_type>
StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
// This branch is also taken for np.complex128:
// isinstance(np.complex128(3), complex) returns True
// isinstance(np.complex64(3), complex) returns False
Py_complex result = PyComplex_AsCComplex(h.ptr());
if (result.real == -1.0 && PyErr_Occurred()) {
PyErr_Clear();
throw std::runtime_error("Could not convert the complex number");
}
if (jax_enable_x64) {
xla::complex128 data(result.real, result.imag);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
return DevicePutResult(
ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
&data, shape,
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, to_device)),
/*weak_type=*/weak_type);
} else {
xla::complex64 data(result.real, result.imag);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
return DevicePutResult(
ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
&data, shape,
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, to_device)),
/*weak_type=*/weak_type);
}
}
StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
return InvalidArgument(
"Non-trivial lazy expression not supported in C++. "
"Falling back to Python.");
}
PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
// Same block as in the previous `if (is_py_buffer)`.
if (buffer->device().contents == to_device) {
return DevicePutResult(buffer->buffer(), weak_type);
} else {
std::unique_ptr<PjRtBuffer> copied_buffer =
ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
return DevicePutResult(std::move(copied_buffer), weak_type);
}
}
// Do not convert types, and only call PjRtBufferFromPyval, independently
// of the value of jax_enable_x64.
DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
std::unique_ptr<xla::PjRtBuffer> buffer =
ValueOrThrow(pyclient.PjRtBufferFromPyval(
h, to_device,
/*force_copy=*/false, /*host_buffer_semantics=*/
xla::PjRtClient::HostBufferSemantics::kZeroCopy));
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}
DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
if (jax_enable_x64) {
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
h, pyclient.pjrt_client(), to_device),
/*weak_type=*/false);
} else {
return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
h, pyclient.pjrt_client(), to_device),
/*weak_type=*/false);
}
}
DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
if (jax_enable_x64) {
std::unique_ptr<xla::PjRtBuffer> buffer =
ValueOrThrow(pyclient.PjRtBufferFromPyval(
h, to_device,
/*force_copy=*/false, /*host_buffer_semantics=*/
xla::PjRtClient::HostBufferSemantics::kZeroCopy));
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
} else {
static const auto* numpy = new py::module(py::module::import("numpy"));
const auto& np_array = numpy->attr("array");
// Note that this is calling back to Python!
std::unique_ptr<xla::PjRtBuffer> buffer =
ValueOrThrow(pyclient.PjRtBufferFromPyval(
np_array(h, py::dtype("uint32")), to_device,
/*force_copy=*/false, /*host_buffer_semantics=*/
xla::PjRtClient::HostBufferSemantics::kZeroCopy));
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}
}
StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
py::array numpy_array = py::cast<py::array>(h);
if (IsFloat0(numpy_array)) {
return InvalidArgument("%s",
"float0 numpy arrays not supported in C++. "
"Falling back to Python.");
}
// If jax_enable_x64 is not set, we need to coerce 32 bits types.
// Note that this is calling back to Python!
if (!jax_enable_x64) {
const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
if (to_dtype) {
static const auto* numpy = new py::module(py::module::import("numpy"));
const auto& np_array = numpy->attr("array");
numpy_array = np_array(numpy_array, *to_dtype);
}
}
std::unique_ptr<xla::PjRtBuffer> buffer =
ValueOrThrow(pyclient.PjRtBufferFromPyval(
numpy_array, to_device,
/*force_copy=*/false, /*host_buffer_semantics=*/
xla::PjRtClient::HostBufferSemantics::kZeroCopy));
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}
} // namespace
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
[] {
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
const auto numpy = py::module::import("numpy");
const auto xla_module = py::module::import("jax.interpreters.xla");
const auto& device_array = xla_module.attr("_DeviceArray");
// Python base types.
(*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = HandleBool;
(*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandleInt;
(*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = HandleFloat<true>;
(*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
HandleComplex<true>;
// DeviceArray and co.
const auto pxla_module = py::module::import("jax.interpreters.pxla");
const auto& sda = pxla_module.attr("ShardedDeviceArray");
(*p)[device_array.ptr()] = HandleDeviceArray;
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = HandleDeviceArray;
(*p)[sda.ptr()] = HandleBufferFromPyval;
// Numpy arrays.
(*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;
// Numpy scalar types. For some of them, we share the handler with
// Python types (np_int64, np_float64, np_complex128).
(*p)[dtypes.np_bool.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_int8.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_int16.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_int32.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_int64.ptr()] = HandleNpBool;
(*p)[dtypes.np_uint8.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_uint16.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_uint32.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_uint64.ptr()] = HandleUint64;
(*p)[dtypes.np_float16.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_float32.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_float64.ptr()] = HandleFloat<false>;
(*p)[dtypes.np_complex64.ptr()] = HandleBufferFromPyval;
(*p)[dtypes.np_complex128.ptr()] = HandleComplex<false>;
(*p)[dtypes.np_longlong.ptr()] = HandleNpBool;
(*p)[dtypes.np_intc.ptr()] = HandleBufferFromPyval;
return p;
}();
auto res = handlers->find(arg.get_type().ptr());
if (res == handlers->end()) {
for (auto base_class : arg.get_type().attr("mro")()) {
res = handlers->find(base_class.ptr());
if (res != handlers->end()) {
return res->second(arg, to_device, jax_enable_x64, pyclient);
}
}
return InvalidArgument(
"%s", absl::StrCat(
"Not supported: The C++ jax jit execution path, only accepts "
"DeviceArray, Numpy arrays scalars of supported types "
"(see implementation), or Python scalars. Got type ",
py::cast<std::string>(py::str(arg.get_type()))));
} else {
return res->second(arg, to_device, jax_enable_x64, pyclient);
}
}
namespace {
struct CacheEntry {
std::shared_ptr<xla::PyExecutable> executable;
PyTreeDef out_pytree_def;
// We use Python types within the vector because this is what we will be
// returning to Python. No need to convert back and forth.
// We need py::object to maintain the objects alive.
std::vector<py::object> out_avals;
// The processing done in `AddCacheEntry` ensures that LazyExpr are stored as
// `py::none()`.
std::vector<py::object> out_lazy_exprs;
py::object sticky_device;
// Ensures a single thread performs the compilation for a given executable.
//
// The first thread (holding the GIL) will create the CacheEntry associated to
// a signature and if the object has been insterted already, other threads
// will wait for the notification.
absl::Notification compilation_complete;
absl::optional<Status> compilation_error = absl::nullopt;
// Trivial computation will fallback to Python.
// Running a jax(pmap) will also fallback to Python.
bool fall_back_to_python = false;
};
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class CompiledFunction {
public:
CompiledFunction(py::function fun, py::function cache_miss,
py::function get_device, py::function get_jax_enable_x64,
py::function get_jax_disable_jit,
std::vector<int> static_argnums);
~CompiledFunction();
// This function will:
// (a) flatten the inputs using pytree
// (b) get buffer objects from the arguments
// (c) call the executable
// (d) construct `DeviceArray` objects from the outputs
// (e) reconstruct the `PyTree`.
py::object Call(py::args args, py::kwargs kwargs);
// This allows `inspect.signature(cpp_jitted_f)` from Python.
py::object __signature__() {
static const auto* inspect = new py::module(py::module::import("inspect"));
return inspect->attr("signature")(fun_);
}
int cache_size() const { return executables_.size(); }
private:
// Returns nullptr if not present in the cache.
CacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
// Should never return nullptr.
CacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
py::object out_and_fastpath_data);
bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }
bool always_fallback_to_python_ = false;
const py::function fun_; // The Python function to jit.
// See JAX _cpp_jit in api.py for documentation.
const py::function cache_miss_;
// We need to know the static arguments to remove them from the arguments
// passed to the underlying PyExecutable. In sorted order.
std::vector<int> static_argnums_;
// We need a `unique_ptr` here to ensure value pointer stability.
absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;
// As top-level functions are decorated with `jax.jit`, when
// `CompiledFunction` is being instantiated from Python, the clients are not
// yet available (done after GoogleInit). They will be during the first call
// to `Call`.
// A function taking no arguments and returning the default device and whether
// jax.jit has been committed to it.
const py::function get_jax_enable_x64_;
const py::function get_jax_disable_jit_;
const py::function get_device_;
// The writing of the following is protected by the mutex.
absl::Mutex mu_;
// The value of the Python flag. The value will be computed only during the
// first object call, because GoogleInit must have been executed.
absl::optional<bool> jax_enable_x64_ = absl::nullopt;
absl::optional<bool> jax_disable_jit_ = absl::nullopt;
// The logic if the following:
// - if `device` or `backend` are not specified to `jax.jit`, we will use
// the input sticky buffer device, or `default_device_` if there is no
// such sticky buffer.
// - When one of `device` or `backend` is specified, this will determine
// the `default_device_` which will be used as the targeted device. In
// which case, we will always copy input buffers to this device.
std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
xla::ClientAndPtr<PjRtDevice> default_pydevice_;
xla::PjRtDevice* default_device_ = nullptr;
bool is_committed_;
};
CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
py::function get_device,
py::function get_jax_enable_x64,
py::function get_jax_disable_jit,
std::vector<int> static_argnums)
: fun_(std::move(fun)),
cache_miss_(std::move(cache_miss)),
static_argnums_(std::move(static_argnums)),
get_jax_enable_x64_(get_jax_enable_x64),
get_jax_disable_jit_(get_jax_disable_jit),
get_device_(std::move(get_device)) {
std::sort(static_argnums_.begin(), static_argnums_.end());
}
CompiledFunction::~CompiledFunction() {
for (const auto& entry : executables_) {
entry.first.DecRef();
}
}
// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
//
// Returns `OkStatus()` on success. Returning an error should lead to calling
// the Python fallback.
Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
xla::PjRtDevice* default_device, bool is_committed,
ParsedArgumentsAsBuffers& arguments) {
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
auto& keep_alive = arguments.keep_alive;
int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
arg_buffers.reserve(num_flat_dynamic_args);
arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args);
static const auto* xla_module =
new py::module(py::module::import("jax.interpreters.xla"));
const auto& device_array = xla_module->attr("_DeviceArray");
// When the jitted function is not committed, we first check whether any
// sticky `DeviceArray` is present and on which device they live. See also:
// https://github.com/google/jax/pull/1884
// https://github.com/google/jax/pull/1916 for the rationale why the
// computation follows the data locality.
// It's also similar to PyTorch's behavior.
xla::PjRtDevice* data_device = nullptr;
if (is_committed) {
data_device = default_device;
} else {
for (py::handle arg : arguments.flat_dynamic_args) {
// We specically only deal with DeviceArray (not ShardedDeviceArray).
// (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
xla::PyBuffer* buffer;
if (arg.attr("_device").is_none()) { // Skip non-sticky devices.
continue;
}
try {
// This can fail, e.g. when device_buffer is a `DeviceConstant`.
buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
} catch (const py::cast_error& e) {
return InvalidArgument(
"%s",
absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
"`device_buffer` field is of type ",
py::cast<std::string>(
arg.attr("device_buffer").get_type().str()),
" while a `PyBuffer` was expected."
));
}
xla::PjRtDevice* device = buffer->buffer()->device();
if (data_device && (device != data_device)) {
throw std::invalid_argument(absl::StrCat(
"primitive arguments must be colocated on the same device ("
"C++ jax.jit). Arguments are on devices: ",
device->DebugString(), " and ", data_device->DebugString()));
} else {
data_device = device;
}
}
}
}
if (!data_device) {
// No `DeviceArray` were found default to `default_device`.
data_device = default_device;
}
CHECK(data_device);
arguments.signature.device = data_device;
for (py::handle arg : arguments.flat_dynamic_args) {
TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
DevicePut(arg, data_device, jax_enable_x64, pyclient));
PjRtBuffer* buffer = on_device.buffer;
arg_buffers.push_back(buffer);
if (on_device.owned_buffer) {
keep_alive.emplace_back(std::move(on_device.owned_buffer));
}
ArgSignature sig(buffer->on_host_shape().element_type(),
buffer->on_host_shape().dimensions(), on_device.weak_type);
arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
}
return Status::OK();
}
CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
const CallSignature& signature) {
auto found_iterator = executables_.find(signature);
if (found_iterator != executables_.end()) { // Cache hit!
if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
py::gil_scoped_release gil_release;
found_iterator->second->compilation_complete.WaitForNotification();
}
if (found_iterator->second->compilation_error) {
throw std::invalid_argument(
found_iterator->second->compilation_error.value().error_message());
}
return found_iterator->second.get();
}
return nullptr;
}
CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
const py::kwargs& kwargs,
const CallSignature& signature,
py::object out_and_fastpath_data) {
// We need to insert the element.
auto result = executables_.emplace(signature, std::make_unique<CacheEntry>());
auto it = result.first;
CacheEntry* cache_entry = it->second.get();
// CallSignatures in the cache own their keyword argument reference.
result.first->first.IncRef();
py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
CHECK_EQ(tuple.size(), 2);
if (tuple[1].is_none()) {
cache_entry->fall_back_to_python = true;
cache_entry->compilation_complete.Notify();
return cache_entry;
}
py::tuple executable_handlers_out_tree = py::cast<py::tuple>(tuple[1]);
if (executable_handlers_out_tree.size() != 5) {
throw std::runtime_error(absl::StrCat(
"The versions of jaxlib and Jax are incompatible (jaxlib is too recent "
"compared to Jax. Upgrade Jax is advised. The C++ code expects "
"5 arguments but ",
executable_handlers_out_tree.size(), " where provided: ",
py::cast<std::string>(
py::str(py::repr(executable_handlers_out_tree)))));
}
// (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
executable_handlers_out_tree[0]);
cache_entry->executable = std::move(executable);
int num_devices =
cache_entry->executable->pjrt_executable().addressable_devices().size();
// The presence of jit(pmap) is detected from Python.
CHECK_EQ(num_devices, 1);
auto out_tree = py::cast<PyTreeDef>(executable_handlers_out_tree[1]);
cache_entry->out_pytree_def = std::move(out_tree);
cache_entry->sticky_device =
py::cast<py::object>(executable_handlers_out_tree[2]);
auto avals = py::cast<py::list>(executable_handlers_out_tree[3]);
auto lazy_exprs = py::cast<py::list>(executable_handlers_out_tree[4]);
CHECK_EQ(avals.size(), lazy_exprs.size());
cache_entry->out_avals.reserve(avals.size());
cache_entry->out_lazy_exprs.reserve(avals.size());
for (int i = 0; i < avals.size(); ++i) {
py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
py::object lazy_expr = py::reinterpret_borrow<py::object>(lazy_exprs[i]);
cache_entry->out_avals.push_back(shaped_array);
CHECK(lazy_expr.is_none() || !IsTrivialLazyExpr(lazy_expr));
cache_entry->out_lazy_exprs.push_back(lazy_expr);
}
cache_entry->compilation_complete.Notify();
return cache_entry;
}
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
if (always_fallback_to_python_) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// Delayed values are retrieved on the first call to `Call`.
if (!default_device_) {
// As we are calling Python code, that may release the GIL, we first hold
// mu_ before holding the GIL.
py::gil_scoped_release gil_release;
{
absl::MutexLock lock1(&mu_);
py::gil_scoped_acquire gil_aquire;
jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
if (!default_device_) {
py::object device_and_is_committed = get_device_();
try {
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
device_and_is_committed.attr("default_device"));
} catch (const py::cast_error& e) {
// Pathways and Cloud TPU 2VM runtime.
always_fallback_to_python_ = true;
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
default_pyclient_ = default_pydevice_.client;
default_device_ = default_pydevice_.contents;
if (!default_device_) { // UPTC
always_fallback_to_python_ = true;
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
is_committed_ =
py::cast<bool>(device_and_is_committed.attr("committed_to_device"));
}
}
}
CHECK(default_device_);
if (JitIsDisabled()) {
return fun_(*args, **kwargs);
}
ParsedArgumentsAsBuffers arguments;
if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
// jit function will be called if any of the dynamic arguments is unsupported.
if (!ConvertArgsToBuffers(jax_enable_x64_.value(), *default_pyclient_,
default_device_, is_committed_, arguments)
.ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
CacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);
if (!cache_entry) {
py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
cache_entry = GetCacheEntryIfPresent(arguments.signature);
if (!cache_entry) {
cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
out_and_fastpath_data);
}
CHECK(cache_entry);
if (cache_entry->fall_back_to_python) {
return py::cast<py::tuple>(out_and_fastpath_data)[0];
}
// As we have already computed the results, we can return it.
// It's even *required* e.g. if there are donated arguments, because
// otherwise the buffer which has been donated already will be invalid.
return py::cast<py::tuple>(out_and_fastpath_data)[0];
}
CHECK(cache_entry);
if (cache_entry->fall_back_to_python) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
std::vector<std::unique_ptr<xla::PyBuffer>> outputs =
ValueOrThrow(cache_entry->executable->PjRtExecute(arguments.arg_buffers));
const std::vector<py::object>& out_avals = cache_entry->out_avals;
const std::vector<py::object>& out_lazy_exprs = cache_entry->out_lazy_exprs;
const py::object& sticky_device = cache_entry->sticky_device;
py::list flat_device_arrays;
for (int i = 0; i < outputs.size(); ++i) {
auto& buffer = outputs[i];
if (out_lazy_exprs[i].is_none()) { // No LazyExpr.
buffer->SetAval(out_avals[i]);
buffer->SetStickyDevice(sticky_device);
flat_device_arrays.append(py::cast(std::move(outputs[i])));
} else {
static const auto* xla_module =
new py::module(py::module::import("jax.interpreters.xla"));
static const auto* device_array =
new py::handle(xla_module->attr("_DeviceArray"));
flat_device_arrays.append(
(*device_array)(out_avals[i], sticky_device, out_lazy_exprs[i],
py::cast(std::move(outputs[i]))));
}
}
return cache_entry->out_pytree_def.Unflatten(flat_device_arrays);
}
} // namespace
void BuildJaxjitSubmodule(pybind11::module& m) {
py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");
py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
jitlib, "CompiledFunction");
cfun.def("__call__", &CompiledFunction::Call);
cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__);
jitlib.def("set_disable_jit", &SetDisableJit);
jitlib.def("get_disable_jit", &GetDisableJit);
jitlib.def(
"jit",
[](py::function fun, py::function cache_miss, py::function get_device,
py::function get_jax_enable_x64, py::function get_jax_disable_jit,
std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> {
return std::make_unique<CompiledFunction>(
std::move(fun), std::move(cache_miss), std::move(get_device),
std::move(get_jax_enable_x64), std::move(get_jax_disable_jit),
std::move(static_argnums));
});
// This function is yet a full replacement for the Python one, because:
// (a) it does not support abstract types,
// (b) it does not set the device stickiness yet.
// TODO(jblespiau): Finish the replacement of the Python feature.
jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
ClientAndPtr<PjRtDevice> to_device) {
std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
StatusOr<DevicePutResult> results =
DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
if (!results.ok()) {
throw std::runtime_error(results.status().error_message());
}
if (results->owned_buffer) {
auto buffer = std::make_unique<PyBuffer>(
pyclient, std::move(results->owned_buffer), Traceback::Get());
static const auto* jax_core =
new py::module(py::module::import("jax.core"));
static const auto* shaped_array =
new py::handle(jax_core->attr("ShapedArray"));
buffer->SetAval((*shaped_array)(
buffer->python_shape(), buffer->python_dtype(), results->weak_type));
buffer->SetStickyDevice(py::none());
return py::cast(std::move(buffer));
} else {
return py::cast<py::object>(obj);
}
});
// All private members are only for testing purposes
cfun.def("_cache_size", &CompiledFunction::cache_size);
jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object {
py::dtype dtype = py::dtype::from_args(obj);
const py::dtype* res = DtypeTo32BitDtype(dtype);
if (res) {
return *res;
} else {
return py::none();
}
});
jitlib.def("_is_float0", &IsFloat0);
jitlib.def("_is_trivial", &IsTrivialLazyExpr);
}
} // namespace xla