Move jax_jit functions to the jax namespace.

PiperOrigin-RevId: 352208505
Change-Id: Iff1ac8e432f13d41f71b41b6e560c35765549043
This commit is contained in:
Jean-Baptiste Lespiau 2021-01-16 15:06:30 -08:00 committed by TensorFlower Gardener
parent 3b94e9cfdb
commit 10c06df949
3 changed files with 131 additions and 117 deletions

View File

@ -53,7 +53,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/status.h"
namespace xla {
namespace jax {
namespace py = pybind11;
@ -180,11 +180,11 @@ H AbslHashValue(H h, const CallSignature& s) {
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments) {
xla::Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments) {
if (static_argnums.size() > args.size()) {
return InvalidArgument(
return xla::InvalidArgument(
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
}
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
@ -196,7 +196,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
for (size_t i = 0; i < args.size(); ++i) {
if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
static_argnums.end()) {
PyTreeDef pytree_def;
xla::PyTreeDef pytree_def;
pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
} else {
@ -210,7 +210,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
py_kwargs.end());
// We first intern the keys, then sort them (by name, as in the Python path)
// (see also PyTreeDef::Flatten) and then create the signatures.
// (see also xla::PyTreeDef::Flatten) and then create the signatures.
// TODO(jblespiau): We should be able to sort the keys by interned-key
// pointers, but this requires the Python compilation to do the same.
arguments.signature.keyword_args.resize(kwargs.size());
@ -236,7 +236,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
kwargs[i].second, arguments.flat_dynamic_args);
}
return Status::OK();
return xla::Status::OK();
}
namespace {
@ -374,11 +374,11 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
namespace {
using ToArgSignatureHandler =
std::function<StatusOr<ArgSignature>(py::handle, bool)>;
std::function<xla::StatusOr<ArgSignature>(py::handle, bool)>;
}
StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
bool jax_enable_x64) {
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
bool jax_enable_x64) {
static const absl::flat_hash_map<PyObject*,
ToArgSignatureHandler>* const handlers = [] {
auto p = new absl::flat_hash_map<PyObject*, ToArgSignatureHandler>();
@ -389,38 +389,40 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
// The 4 Python native types.
ToArgSignatureHandler bool_handler = [](py::handle,
bool) -> StatusOr<ArgSignature> {
return ArgSignature(PrimitiveType::PRED, {}, true);
ToArgSignatureHandler bool_handler =
[](py::handle, bool) -> xla::StatusOr<ArgSignature> {
return ArgSignature(xla::PrimitiveType::PRED, {}, true);
};
ToArgSignatureHandler int_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
if (jax_enable_x64) {
return ArgSignature(PrimitiveType::S64, {}, true);
return ArgSignature(xla::PrimitiveType::S64, {}, true);
} else {
return ArgSignature(PrimitiveType::S32, {}, true);
return ArgSignature(xla::PrimitiveType::S32, {}, true);
}
};
ToArgSignatureHandler float_handler =
[&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[&dtypes](py::handle h,
bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
// Only Python native types has a True weak_type.
bool weak_type = !py::isinstance(h, dtypes.np_float64);
if (jax_enable_x64) {
return ArgSignature(PrimitiveType::F64, {}, weak_type);
return ArgSignature(xla::PrimitiveType::F64, {}, weak_type);
} else {
return ArgSignature(PrimitiveType::F32, {}, weak_type);
return ArgSignature(xla::PrimitiveType::F32, {}, weak_type);
}
};
ToArgSignatureHandler complex_handler =
[&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[&dtypes](py::handle h,
bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
// Note that this branch is also taken for np.complex128:
// isinstance(np.complex128(3), complex) returns True
// isinstance(np.complex64(3), complex) returns False
bool weak_type = !py::isinstance(h, dtypes.np_complex128);
if (jax_enable_x64) {
return ArgSignature(PrimitiveType::C128, {}, weak_type);
return ArgSignature(xla::PrimitiveType::C128, {}, weak_type);
} else {
return ArgSignature(PrimitiveType::C64, {}, weak_type);
return ArgSignature(xla::PrimitiveType::C64, {}, weak_type);
}
};
@ -432,18 +434,19 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
// The Buffer types
// PyBuffer necessarily has a trivial LazyExpr, no need to check it.
ToArgSignatureHandler buffer_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
buffer->buffer()->on_host_shape().dimensions(),
weak_type);
};
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = buffer_handler;
(*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
ToArgSignatureHandler device_array_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
py::handle aval = h.attr("aval");
TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(aval.attr("dtype")));
TF_ASSIGN_OR_RETURN(auto dtype,
xla::DtypeToPrimitiveType(aval.attr("dtype")));
return ArgSignature(dtype,
py::cast<std::vector<int64>>(aval.attr("shape")),
py::cast<py::bool_>(aval.attr("weak_type")));
@ -452,10 +455,10 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
(*p)[device_array.ptr()] = device_array_handler;
ToArgSignatureHandler numpy_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
py::array numpy_array = py::cast<py::array>(h);
if (IsFloat0(numpy_array)) {
return InvalidArgument(
return xla::InvalidArgument(
"float0 numpy arrays not supported in C++. "
"Falling back to Python.");
}
@ -463,11 +466,11 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
const py::dtype raw_dtype = numpy_array.dtype();
const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype);
PrimitiveType dtype;
xla::PrimitiveType dtype;
if (to_dtype) {
TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(*to_dtype));
TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(*to_dtype));
} else {
TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(raw_dtype));
TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(raw_dtype));
}
// We need the reinterpret_cast for the OSS version to build.
return ArgSignature(dtype,
@ -477,7 +480,7 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
/*weak_type=*/false);
}
TF_ASSIGN_OR_RETURN(auto dtype,
DtypeToPrimitiveType(numpy_array.dtype()));
xla::DtypeToPrimitiveType(numpy_array.dtype()));
return ArgSignature(dtype,
absl::MakeConstSpan(reinterpret_cast<const int64*>(
numpy_array.shape()),
@ -489,27 +492,28 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
(*p)[ndarray.ptr()] = numpy_handler;
ToArgSignatureHandler np_uint64_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
if (jax_enable_x64) {
return ArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
return ArgSignature(xla::PrimitiveType::U64, {}, /*weak_type=*/false);
} else {
return ArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
return ArgSignature(xla::PrimitiveType::U32, {}, /*weak_type=*/false);
}
};
ToArgSignatureHandler np_int_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
if (jax_enable_x64) {
return ArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
return ArgSignature(xla::PrimitiveType::S64, {}, /*weak_type=*/false);
} else {
return ArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
return ArgSignature(xla::PrimitiveType::S32, {}, /*weak_type=*/false);
}
};
ToArgSignatureHandler numpy_array_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
// This block deals with all numpy scalar types, except for int64_dt,
// float64_dt and complex128_dt which are taken care of in previous if
// blocks.
TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(h.attr("dtype")));
TF_ASSIGN_OR_RETURN(auto dtype,
xla::DtypeToPrimitiveType(h.attr("dtype")));
return ArgSignature(dtype, {}, /*weak_type=*/false);
};
@ -545,7 +549,7 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
return res->second(arg, jax_enable_x64);
}
}
return InvalidArgument(
return xla::InvalidArgument(
"%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts "
"Buffer/DeviceArray/ShardedDeviceArray, Numpy "
"arrays scalars of supported types "
@ -557,17 +561,17 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
}
namespace {
using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
py::handle, PjRtDevice*, bool, xla::PyClient&)>;
using DevicePutFunc = std::function<xla::StatusOr<DevicePutResult>(
py::handle, xla::PjRtDevice*, bool, xla::PyClient&)>;
DevicePutResult HandleBool(py::handle h, PjRtDevice* to_device,
DevicePutResult HandleBool(py::handle h, xla::PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
h, pyclient.pjrt_client(), to_device),
/*weak_type=*/true);
}
DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
DevicePutResult HandleInt(py::handle obj, xla::PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
if (jax_enable_x64) {
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
@ -581,9 +585,10 @@ DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
}
template <bool weak_type>
StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
xla::StatusOr<DevicePutResult> HandleFloat(py::handle h,
xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
if (jax_enable_x64) {
return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
h, pyclient.pjrt_client(), to_device),
@ -596,9 +601,10 @@ StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
}
template <bool weak_type>
StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
xla::StatusOr<DevicePutResult> HandleComplex(py::handle h,
xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
// This branch is also taken for np.complex128:
// isinstance(np.complex128(3), complex) returns True
// isinstance(np.complex64(3), complex) returns False
@ -628,22 +634,22 @@ StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
}
}
StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
xla::StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
return InvalidArgument(
return xla::InvalidArgument(
"Non-trivial lazy expression not supported in C++. "
"Falling back to Python.");
}
PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
// Same block as in the previous `if (is_py_buffer)`.
if (buffer->device().contents == to_device) {
return DevicePutResult(buffer->buffer(), weak_type);
} else {
std::unique_ptr<PjRtBuffer> copied_buffer =
std::unique_ptr<xla::PjRtBuffer> copied_buffer =
ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
return DevicePutResult(std::move(copied_buffer), weak_type);
}
@ -651,7 +657,7 @@ StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
// Do not convert types, and only call PjRtBufferFromPyval, independently
// of the value of jax_enable_x64.
DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
DevicePutResult HandleBufferFromPyval(py::handle h, xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
std::unique_ptr<xla::PjRtBuffer> buffer =
@ -662,7 +668,7 @@ DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}
DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
DevicePutResult HandleNpBool(py::handle h, xla::PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
if (jax_enable_x64) {
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
@ -675,7 +681,7 @@ DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
}
}
DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
DevicePutResult HandleUint64(py::handle h, xla::PjRtDevice* to_device,
bool jax_enable_x64, xla::PyClient& pyclient) {
if (jax_enable_x64) {
std::unique_ptr<xla::PjRtBuffer> buffer =
@ -698,14 +704,15 @@ DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
}
}
StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
xla::StatusOr<DevicePutResult> HandleNdarray(py::handle h,
xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
py::array numpy_array = py::cast<py::array>(h);
if (IsFloat0(numpy_array)) {
return InvalidArgument("%s",
"float0 numpy arrays not supported in C++. "
"Falling back to Python.");
return xla::InvalidArgument("%s",
"float0 numpy arrays not supported in C++. "
"Falling back to Python.");
}
// If jax_enable_x64 is not set, we need to coerce 32 bits types.
// Note that this is calling back to Python!
@ -727,9 +734,10 @@ StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
} // namespace
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient) {
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
[] {
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
@ -751,7 +759,8 @@ StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
const auto pxla_module = py::module::import("jax.interpreters.pxla");
const auto& sda = pxla_module.attr("ShardedDeviceArray");
(*p)[device_array.ptr()] = HandleDeviceArray;
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = HandleDeviceArray;
(*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] =
HandleDeviceArray;
(*p)[sda.ptr()] = HandleBufferFromPyval;
// Numpy arrays.
(*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;
@ -786,7 +795,7 @@ StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
return res->second(arg, to_device, jax_enable_x64, pyclient);
}
}
return InvalidArgument(
return xla::InvalidArgument(
"%s", absl::StrCat(
"Not supported: The C++ jax jit execution path, only accepts "
"DeviceArray, Numpy arrays scalars of supported types "
@ -801,7 +810,7 @@ namespace {
struct CacheEntry {
std::shared_ptr<xla::PyExecutable> executable;
PyTreeDef out_pytree_def;
xla::PyTreeDef out_pytree_def;
// We use Python types within the vector because this is what we will be
// returning to Python. No need to convert back and forth.
// We need py::object to maintain the objects alive.
@ -817,7 +826,7 @@ struct CacheEntry {
// a signature and if the object has been insterted already, other threads
// will wait for the notification.
absl::Notification compilation_complete;
absl::optional<Status> compilation_error = absl::nullopt;
absl::optional<xla::Status> compilation_error = absl::nullopt;
// Trivial computation will fallback to Python.
// Running a jax(pmap) will also fallback to Python.
bool fall_back_to_python = false;
@ -896,7 +905,7 @@ class CompiledFunction {
// the `default_device_` which will be used as the targeted device. In
// which case, we will always copy input buffers to this device.
std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
xla::ClientAndPtr<PjRtDevice> default_pydevice_;
xla::ClientAndPtr<xla::PjRtDevice> default_pydevice_;
xla::PjRtDevice* default_device_ = nullptr;
bool is_committed_;
};
@ -924,11 +933,12 @@ CompiledFunction::~CompiledFunction() {
// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
//
// Returns `OkStatus()` on success. Returning an error should lead to calling
// the Python fallback.
Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
xla::PjRtDevice* default_device, bool is_committed,
ParsedArgumentsAsBuffers& arguments) {
// Returns `Okxla::Status()` on success. Returning an error should lead to
// calling the Python fallback.
xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
xla::PjRtDevice* default_device,
bool is_committed,
ParsedArgumentsAsBuffers& arguments) {
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
auto& keep_alive = arguments.keep_alive;
@ -953,7 +963,8 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
for (py::handle arg : arguments.flat_dynamic_args) {
// We specically only deal with DeviceArray (not ShardedDeviceArray).
// (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
if (py::isinstance<xla::PyBuffer>(arg) ||
arg.get_type().is(device_array)) {
xla::PyBuffer* buffer;
if (arg.attr("_device").is_none()) { // Skip non-sticky devices.
continue;
@ -962,7 +973,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
// This can fail, e.g. when device_buffer is a `DeviceConstant`.
buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
} catch (const py::cast_error& e) {
return InvalidArgument(
return xla::InvalidArgument(
"%s",
absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
"`device_buffer` field is of type ",
@ -995,7 +1006,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
DevicePut(arg, data_device, jax_enable_x64, pyclient));
PjRtBuffer* buffer = on_device.buffer;
xla::PjRtBuffer* buffer = on_device.buffer;
arg_buffers.push_back(buffer);
if (on_device.owned_buffer) {
keep_alive.emplace_back(std::move(on_device.owned_buffer));
@ -1005,7 +1016,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
buffer->on_host_shape().dimensions(), on_device.weak_type);
arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
}
return Status::OK();
return xla::Status::OK();
}
CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
@ -1063,7 +1074,7 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
// The presence of jit(pmap) is detected from Python.
CHECK_EQ(num_devices, 1);
auto out_tree = py::cast<PyTreeDef>(executable_handlers_out_tree[1]);
auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
cache_entry->out_pytree_def = std::move(out_tree);
cache_entry->sticky_device =
@ -1105,7 +1116,7 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
if (!default_device_) {
py::object device_and_is_committed = get_device_();
try {
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
default_pydevice_ = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
device_and_is_committed.attr("default_device"));
} catch (const py::cast_error& e) {
// Pathways and Cloud TPU 2VM runtime.
@ -1217,16 +1228,16 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
// (b) it does not set the device stickiness yet.
// TODO(jblespiau): Finish the replacement of the Python feature.
jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
ClientAndPtr<PjRtDevice> to_device) {
xla::ClientAndPtr<xla::PjRtDevice> to_device) {
std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
StatusOr<DevicePutResult> results =
xla::StatusOr<DevicePutResult> results =
DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
if (!results.ok()) {
throw std::runtime_error(results.status().error_message());
}
if (results->owned_buffer) {
auto buffer = std::make_unique<PyBuffer>(
pyclient, std::move(results->owned_buffer), Traceback::Get());
auto buffer = std::make_unique<xla::PyBuffer>(
pyclient, std::move(results->owned_buffer), xla::Traceback::Get());
static const auto* jax_core =
new py::module(py::module::import("jax.core"));
@ -1248,9 +1259,10 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
[](const ArgSignature& sig) {
return PrimitiveTypeToDtype(sig.dtype);
})
.def_property_readonly(
"shape",
[](const ArgSignature& sig) { return IntSpanToTuple(sig.shape); })
.def_property_readonly("shape",
[](const ArgSignature& sig) {
return xla::IntSpanToTuple(sig.shape);
})
.def_readonly("weak_type", &ArgSignature::weak_type);
jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue);
@ -1269,4 +1281,4 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
jitlib.def("_is_trivial", &IsTrivialLazyExpr);
}
} // namespace xla
} // namespace jax

View File

@ -24,15 +24,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace jax {
// Describes the abstract shape and dtype of an argument.
struct ArgSignature {
ArgSignature(PrimitiveType dtype, absl::Span<const int64> shape,
ArgSignature(xla::PrimitiveType dtype, absl::Span<const int64> shape,
bool weak_type)
: dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
// This is the XLA dtype of the object.
const PrimitiveType dtype;
const xla::PrimitiveType dtype;
const absl::InlinedVector<int64, 4> shape;
// JAX arguments can be of weak type, if and only if they are Python scalars
// or `DeviceArray` values such that `aval.weak_type` is true.
@ -66,7 +66,7 @@ struct CallSignature {
// To avoid comparing strings, we intern the kwargs strings.
// The compilation cache holds a reference to all the keys.
pybind11::handle key;
PyTreeDef value_treedef;
xla::PyTreeDef value_treedef;
bool operator==(const KwargEntry& other) const {
return key.ptr() == other.key.ptr() &&
value_treedef == other.value_treedef;
@ -78,13 +78,13 @@ struct CallSignature {
// order of their argnum index.
std::vector<pybind11::object> static_args;
// A PyTreeDef for each positional dynamic (i.e. not static) argument.
std::vector<PyTreeDef> dynamic_positional_args_treedef;
std::vector<xla::PyTreeDef> dynamic_positional_args_treedef;
// Keyword arguments. Sorted by the keyword name.
std::vector<KwargEntry> keyword_args;
// Shape and dtype for both the dynamic positional arguments and the keyword
// arguments (sorted by keyword name).
std::vector<ArgSignature> dynamic_args_signatures;
PjRtDevice* device;
xla::PjRtDevice* device;
bool operator==(const CallSignature& other) const;
bool operator!=(const CallSignature& other) const {
@ -132,28 +132,28 @@ struct ParsedArgumentsAsBuffers {
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
Status ParseArguments(const pybind11::args& args,
const pybind11::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments);
xla::Status ParseArguments(const pybind11::args& args,
const pybind11::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments);
struct DevicePutResult {
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
explicit DevicePutResult(xla::PjRtBuffer* b, bool weak_type)
: buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
DevicePutResult(std::unique_ptr<PjRtBuffer> new_buffer, bool weak_type)
DevicePutResult(std::unique_ptr<xla::PjRtBuffer> new_buffer, bool weak_type)
: buffer(new_buffer.get()),
weak_type(weak_type),
owned_buffer(std::move(new_buffer)) {}
PjRtBuffer* buffer;
xla::PjRtBuffer* buffer;
bool weak_type;
std::unique_ptr<PjRtBuffer> owned_buffer;
std::unique_ptr<xla::PjRtBuffer> owned_buffer;
};
// Returns the ArgSignature associated with an argument. Returns an error if
// the argument is not supported.
StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
bool jax_enable_x64);
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
bool jax_enable_x64);
// Moves a device-like object to be on device.
// - If the object is already on device, `owned_buffer` will be nullptr.
@ -161,15 +161,17 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
// `owned_buffer`.
// In all cases, `buffer` will point to the already existing or newly created
// buffer.
// If `obj` is not convertible to a `PjRtBuffer` from C++, an error will be
// If `obj` is not convertible to a `xla::PjRtBuffer` from C++, an error will be
// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
// supported yet.
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
bool jax_enable_x64, PyClient& pyclient);
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
xla::PjRtDevice* to_device,
bool jax_enable_x64,
xla::PyClient& pyclient);
// The function to call in `xla.cc` to add the bindings for this module.
void BuildJaxjitSubmodule(pybind11::module& m);
} // namespace xla
} // namespace jax
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_

View File

@ -415,7 +415,7 @@ PYBIND11_MODULE(xla_extension, m) {
BuildOpsSubmodule(&m);
BuildOutfeedReceiverSubmodule(&m);
BuildPytreeSubmodule(m);
BuildJaxjitSubmodule(m);
jax::BuildJaxjitSubmodule(m);
jax::BuildPmapSubmodule(m);
BuildTracebackSubmodule(m);