Move jax_jit functions to the jax
namespace.
PiperOrigin-RevId: 352208505 Change-Id: Iff1ac8e432f13d41f71b41b6e560c35765549043
This commit is contained in:
parent
3b94e9cfdb
commit
10c06df949
@ -53,7 +53,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace xla {
|
||||
namespace jax {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@ -180,11 +180,11 @@ H AbslHashValue(H h, const CallSignature& s) {
|
||||
|
||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
xla::Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
if (static_argnums.size() > args.size()) {
|
||||
return InvalidArgument(
|
||||
return xla::InvalidArgument(
|
||||
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
|
||||
}
|
||||
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
|
||||
@ -196,7 +196,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
|
||||
static_argnums.end()) {
|
||||
PyTreeDef pytree_def;
|
||||
xla::PyTreeDef pytree_def;
|
||||
pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
|
||||
arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
|
||||
} else {
|
||||
@ -210,7 +210,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
|
||||
py_kwargs.end());
|
||||
// We first intern the keys, then sort them (by name, as in the Python path)
|
||||
// (see also PyTreeDef::Flatten) and then create the signatures.
|
||||
// (see also xla::PyTreeDef::Flatten) and then create the signatures.
|
||||
// TODO(jblespiau): We should be able to sort the keys by interned-key
|
||||
// pointers, but this requires the Python compilation to do the same.
|
||||
arguments.signature.keyword_args.resize(kwargs.size());
|
||||
@ -236,7 +236,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
|
||||
kwargs[i].second, arguments.flat_dynamic_args);
|
||||
}
|
||||
return Status::OK();
|
||||
return xla::Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -374,11 +374,11 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
|
||||
namespace {
|
||||
|
||||
using ToArgSignatureHandler =
|
||||
std::function<StatusOr<ArgSignature>(py::handle, bool)>;
|
||||
std::function<xla::StatusOr<ArgSignature>(py::handle, bool)>;
|
||||
}
|
||||
|
||||
StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
bool jax_enable_x64) {
|
||||
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
bool jax_enable_x64) {
|
||||
static const absl::flat_hash_map<PyObject*,
|
||||
ToArgSignatureHandler>* const handlers = [] {
|
||||
auto p = new absl::flat_hash_map<PyObject*, ToArgSignatureHandler>();
|
||||
@ -389,38 +389,40 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
|
||||
|
||||
// The 4 Python native types.
|
||||
ToArgSignatureHandler bool_handler = [](py::handle,
|
||||
bool) -> StatusOr<ArgSignature> {
|
||||
return ArgSignature(PrimitiveType::PRED, {}, true);
|
||||
ToArgSignatureHandler bool_handler =
|
||||
[](py::handle, bool) -> xla::StatusOr<ArgSignature> {
|
||||
return ArgSignature(xla::PrimitiveType::PRED, {}, true);
|
||||
};
|
||||
ToArgSignatureHandler int_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
if (jax_enable_x64) {
|
||||
return ArgSignature(PrimitiveType::S64, {}, true);
|
||||
return ArgSignature(xla::PrimitiveType::S64, {}, true);
|
||||
} else {
|
||||
return ArgSignature(PrimitiveType::S32, {}, true);
|
||||
return ArgSignature(xla::PrimitiveType::S32, {}, true);
|
||||
}
|
||||
};
|
||||
ToArgSignatureHandler float_handler =
|
||||
[&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[&dtypes](py::handle h,
|
||||
bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
// Only Python native types has a True weak_type.
|
||||
bool weak_type = !py::isinstance(h, dtypes.np_float64);
|
||||
if (jax_enable_x64) {
|
||||
return ArgSignature(PrimitiveType::F64, {}, weak_type);
|
||||
return ArgSignature(xla::PrimitiveType::F64, {}, weak_type);
|
||||
} else {
|
||||
return ArgSignature(PrimitiveType::F32, {}, weak_type);
|
||||
return ArgSignature(xla::PrimitiveType::F32, {}, weak_type);
|
||||
}
|
||||
};
|
||||
ToArgSignatureHandler complex_handler =
|
||||
[&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[&dtypes](py::handle h,
|
||||
bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
// Note that this branch is also taken for np.complex128:
|
||||
// isinstance(np.complex128(3), complex) returns True
|
||||
// isinstance(np.complex64(3), complex) returns False
|
||||
bool weak_type = !py::isinstance(h, dtypes.np_complex128);
|
||||
if (jax_enable_x64) {
|
||||
return ArgSignature(PrimitiveType::C128, {}, weak_type);
|
||||
return ArgSignature(xla::PrimitiveType::C128, {}, weak_type);
|
||||
} else {
|
||||
return ArgSignature(PrimitiveType::C64, {}, weak_type);
|
||||
return ArgSignature(xla::PrimitiveType::C64, {}, weak_type);
|
||||
}
|
||||
};
|
||||
|
||||
@ -432,18 +434,19 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
// The Buffer types
|
||||
// PyBuffer necessarily has a trivial LazyExpr, no need to check it.
|
||||
ToArgSignatureHandler buffer_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
|
||||
bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
|
||||
return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
|
||||
buffer->buffer()->on_host_shape().dimensions(),
|
||||
weak_type);
|
||||
};
|
||||
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = buffer_handler;
|
||||
(*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
|
||||
ToArgSignatureHandler device_array_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
py::handle aval = h.attr("aval");
|
||||
TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(aval.attr("dtype")));
|
||||
TF_ASSIGN_OR_RETURN(auto dtype,
|
||||
xla::DtypeToPrimitiveType(aval.attr("dtype")));
|
||||
return ArgSignature(dtype,
|
||||
py::cast<std::vector<int64>>(aval.attr("shape")),
|
||||
py::cast<py::bool_>(aval.attr("weak_type")));
|
||||
@ -452,10 +455,10 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
(*p)[device_array.ptr()] = device_array_handler;
|
||||
|
||||
ToArgSignatureHandler numpy_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
py::array numpy_array = py::cast<py::array>(h);
|
||||
if (IsFloat0(numpy_array)) {
|
||||
return InvalidArgument(
|
||||
return xla::InvalidArgument(
|
||||
"float0 numpy arrays not supported in C++. "
|
||||
"Falling back to Python.");
|
||||
}
|
||||
@ -463,11 +466,11 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
const py::dtype raw_dtype = numpy_array.dtype();
|
||||
const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype);
|
||||
|
||||
PrimitiveType dtype;
|
||||
xla::PrimitiveType dtype;
|
||||
if (to_dtype) {
|
||||
TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(*to_dtype));
|
||||
TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(*to_dtype));
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(raw_dtype));
|
||||
TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(raw_dtype));
|
||||
}
|
||||
// We need the reinterpret_cast for the OSS version to build.
|
||||
return ArgSignature(dtype,
|
||||
@ -477,7 +480,7 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
/*weak_type=*/false);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto dtype,
|
||||
DtypeToPrimitiveType(numpy_array.dtype()));
|
||||
xla::DtypeToPrimitiveType(numpy_array.dtype()));
|
||||
return ArgSignature(dtype,
|
||||
absl::MakeConstSpan(reinterpret_cast<const int64*>(
|
||||
numpy_array.shape()),
|
||||
@ -489,27 +492,28 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
(*p)[ndarray.ptr()] = numpy_handler;
|
||||
|
||||
ToArgSignatureHandler np_uint64_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
if (jax_enable_x64) {
|
||||
return ArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
|
||||
return ArgSignature(xla::PrimitiveType::U64, {}, /*weak_type=*/false);
|
||||
} else {
|
||||
return ArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
|
||||
return ArgSignature(xla::PrimitiveType::U32, {}, /*weak_type=*/false);
|
||||
}
|
||||
};
|
||||
ToArgSignatureHandler np_int_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
if (jax_enable_x64) {
|
||||
return ArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
|
||||
return ArgSignature(xla::PrimitiveType::S64, {}, /*weak_type=*/false);
|
||||
} else {
|
||||
return ArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
|
||||
return ArgSignature(xla::PrimitiveType::S32, {}, /*weak_type=*/false);
|
||||
}
|
||||
};
|
||||
ToArgSignatureHandler numpy_array_handler =
|
||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
||||
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||
// This block deals with all numpy scalar types, except for int64_dt,
|
||||
// float64_dt and complex128_dt which are taken care of in previous if
|
||||
// blocks.
|
||||
TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(h.attr("dtype")));
|
||||
TF_ASSIGN_OR_RETURN(auto dtype,
|
||||
xla::DtypeToPrimitiveType(h.attr("dtype")));
|
||||
return ArgSignature(dtype, {}, /*weak_type=*/false);
|
||||
};
|
||||
|
||||
@ -545,7 +549,7 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
return res->second(arg, jax_enable_x64);
|
||||
}
|
||||
}
|
||||
return InvalidArgument(
|
||||
return xla::InvalidArgument(
|
||||
"%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts "
|
||||
"Buffer/DeviceArray/ShardedDeviceArray, Numpy "
|
||||
"arrays scalars of supported types "
|
||||
@ -557,17 +561,17 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
}
|
||||
|
||||
namespace {
|
||||
using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
|
||||
py::handle, PjRtDevice*, bool, xla::PyClient&)>;
|
||||
using DevicePutFunc = std::function<xla::StatusOr<DevicePutResult>(
|
||||
py::handle, xla::PjRtDevice*, bool, xla::PyClient&)>;
|
||||
|
||||
DevicePutResult HandleBool(py::handle h, PjRtDevice* to_device,
|
||||
DevicePutResult HandleBool(py::handle h, xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||
return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
|
||||
h, pyclient.pjrt_client(), to_device),
|
||||
/*weak_type=*/true);
|
||||
}
|
||||
|
||||
DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
|
||||
DevicePutResult HandleInt(py::handle obj, xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||
if (jax_enable_x64) {
|
||||
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
|
||||
@ -581,9 +585,10 @@ DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
|
||||
}
|
||||
|
||||
template <bool weak_type>
|
||||
StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
xla::StatusOr<DevicePutResult> HandleFloat(py::handle h,
|
||||
xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
if (jax_enable_x64) {
|
||||
return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
|
||||
h, pyclient.pjrt_client(), to_device),
|
||||
@ -596,9 +601,10 @@ StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
|
||||
}
|
||||
|
||||
template <bool weak_type>
|
||||
StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
xla::StatusOr<DevicePutResult> HandleComplex(py::handle h,
|
||||
xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
// This branch is also taken for np.complex128:
|
||||
// isinstance(np.complex128(3), complex) returns True
|
||||
// isinstance(np.complex64(3), complex) returns False
|
||||
@ -628,22 +634,22 @@ StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
|
||||
PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
xla::StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
|
||||
xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
|
||||
return InvalidArgument(
|
||||
return xla::InvalidArgument(
|
||||
"Non-trivial lazy expression not supported in C++. "
|
||||
"Falling back to Python.");
|
||||
}
|
||||
PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
|
||||
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
|
||||
bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
|
||||
// Same block as in the previous `if (is_py_buffer)`.
|
||||
if (buffer->device().contents == to_device) {
|
||||
return DevicePutResult(buffer->buffer(), weak_type);
|
||||
} else {
|
||||
std::unique_ptr<PjRtBuffer> copied_buffer =
|
||||
std::unique_ptr<xla::PjRtBuffer> copied_buffer =
|
||||
ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
|
||||
return DevicePutResult(std::move(copied_buffer), weak_type);
|
||||
}
|
||||
@ -651,7 +657,7 @@ StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
|
||||
|
||||
// Do not convert types, and only call PjRtBufferFromPyval, independently
|
||||
// of the value of jax_enable_x64.
|
||||
DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
|
||||
DevicePutResult HandleBufferFromPyval(py::handle h, xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
std::unique_ptr<xla::PjRtBuffer> buffer =
|
||||
@ -662,7 +668,7 @@ DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
|
||||
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
|
||||
}
|
||||
|
||||
DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
|
||||
DevicePutResult HandleNpBool(py::handle h, xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||
if (jax_enable_x64) {
|
||||
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
|
||||
@ -675,7 +681,7 @@ DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
|
||||
}
|
||||
}
|
||||
|
||||
DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
|
||||
DevicePutResult HandleUint64(py::handle h, xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||
if (jax_enable_x64) {
|
||||
std::unique_ptr<xla::PjRtBuffer> buffer =
|
||||
@ -698,14 +704,15 @@ DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
xla::StatusOr<DevicePutResult> HandleNdarray(py::handle h,
|
||||
xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
py::array numpy_array = py::cast<py::array>(h);
|
||||
if (IsFloat0(numpy_array)) {
|
||||
return InvalidArgument("%s",
|
||||
"float0 numpy arrays not supported in C++. "
|
||||
"Falling back to Python.");
|
||||
return xla::InvalidArgument("%s",
|
||||
"float0 numpy arrays not supported in C++. "
|
||||
"Falling back to Python.");
|
||||
}
|
||||
// If jax_enable_x64 is not set, we need to coerce 32 bits types.
|
||||
// Note that this is calling back to Python!
|
||||
@ -727,9 +734,10 @@ StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
|
||||
xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient) {
|
||||
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
|
||||
[] {
|
||||
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
|
||||
@ -751,7 +759,8 @@ StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
||||
const auto pxla_module = py::module::import("jax.interpreters.pxla");
|
||||
const auto& sda = pxla_module.attr("ShardedDeviceArray");
|
||||
(*p)[device_array.ptr()] = HandleDeviceArray;
|
||||
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = HandleDeviceArray;
|
||||
(*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] =
|
||||
HandleDeviceArray;
|
||||
(*p)[sda.ptr()] = HandleBufferFromPyval;
|
||||
// Numpy arrays.
|
||||
(*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;
|
||||
@ -786,7 +795,7 @@ StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
||||
return res->second(arg, to_device, jax_enable_x64, pyclient);
|
||||
}
|
||||
}
|
||||
return InvalidArgument(
|
||||
return xla::InvalidArgument(
|
||||
"%s", absl::StrCat(
|
||||
"Not supported: The C++ jax jit execution path, only accepts "
|
||||
"DeviceArray, Numpy arrays scalars of supported types "
|
||||
@ -801,7 +810,7 @@ namespace {
|
||||
|
||||
struct CacheEntry {
|
||||
std::shared_ptr<xla::PyExecutable> executable;
|
||||
PyTreeDef out_pytree_def;
|
||||
xla::PyTreeDef out_pytree_def;
|
||||
// We use Python types within the vector because this is what we will be
|
||||
// returning to Python. No need to convert back and forth.
|
||||
// We need py::object to maintain the objects alive.
|
||||
@ -817,7 +826,7 @@ struct CacheEntry {
|
||||
// a signature and if the object has been insterted already, other threads
|
||||
// will wait for the notification.
|
||||
absl::Notification compilation_complete;
|
||||
absl::optional<Status> compilation_error = absl::nullopt;
|
||||
absl::optional<xla::Status> compilation_error = absl::nullopt;
|
||||
// Trivial computation will fallback to Python.
|
||||
// Running a jax(pmap) will also fallback to Python.
|
||||
bool fall_back_to_python = false;
|
||||
@ -896,7 +905,7 @@ class CompiledFunction {
|
||||
// the `default_device_` which will be used as the targeted device. In
|
||||
// which case, we will always copy input buffers to this device.
|
||||
std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
|
||||
xla::ClientAndPtr<PjRtDevice> default_pydevice_;
|
||||
xla::ClientAndPtr<xla::PjRtDevice> default_pydevice_;
|
||||
xla::PjRtDevice* default_device_ = nullptr;
|
||||
bool is_committed_;
|
||||
};
|
||||
@ -924,11 +933,12 @@ CompiledFunction::~CompiledFunction() {
|
||||
// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
|
||||
// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
|
||||
//
|
||||
// Returns `OkStatus()` on success. Returning an error should lead to calling
|
||||
// the Python fallback.
|
||||
Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
xla::PjRtDevice* default_device, bool is_committed,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
// Returns `Okxla::Status()` on success. Returning an error should lead to
|
||||
// calling the Python fallback.
|
||||
xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
xla::PjRtDevice* default_device,
|
||||
bool is_committed,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
|
||||
auto& keep_alive = arguments.keep_alive;
|
||||
|
||||
@ -953,7 +963,8 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
for (py::handle arg : arguments.flat_dynamic_args) {
|
||||
// We specically only deal with DeviceArray (not ShardedDeviceArray).
|
||||
// (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
|
||||
if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
|
||||
if (py::isinstance<xla::PyBuffer>(arg) ||
|
||||
arg.get_type().is(device_array)) {
|
||||
xla::PyBuffer* buffer;
|
||||
if (arg.attr("_device").is_none()) { // Skip non-sticky devices.
|
||||
continue;
|
||||
@ -962,7 +973,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
// This can fail, e.g. when device_buffer is a `DeviceConstant`.
|
||||
buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
|
||||
} catch (const py::cast_error& e) {
|
||||
return InvalidArgument(
|
||||
return xla::InvalidArgument(
|
||||
"%s",
|
||||
absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
|
||||
"`device_buffer` field is of type ",
|
||||
@ -995,7 +1006,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
|
||||
DevicePut(arg, data_device, jax_enable_x64, pyclient));
|
||||
|
||||
PjRtBuffer* buffer = on_device.buffer;
|
||||
xla::PjRtBuffer* buffer = on_device.buffer;
|
||||
arg_buffers.push_back(buffer);
|
||||
if (on_device.owned_buffer) {
|
||||
keep_alive.emplace_back(std::move(on_device.owned_buffer));
|
||||
@ -1005,7 +1016,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
buffer->on_host_shape().dimensions(), on_device.weak_type);
|
||||
arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
|
||||
}
|
||||
return Status::OK();
|
||||
return xla::Status::OK();
|
||||
}
|
||||
|
||||
CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
|
||||
@ -1063,7 +1074,7 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
|
||||
// The presence of jit(pmap) is detected from Python.
|
||||
CHECK_EQ(num_devices, 1);
|
||||
|
||||
auto out_tree = py::cast<PyTreeDef>(executable_handlers_out_tree[1]);
|
||||
auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
|
||||
cache_entry->out_pytree_def = std::move(out_tree);
|
||||
|
||||
cache_entry->sticky_device =
|
||||
@ -1105,7 +1116,7 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
if (!default_device_) {
|
||||
py::object device_and_is_committed = get_device_();
|
||||
try {
|
||||
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
|
||||
default_pydevice_ = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
|
||||
device_and_is_committed.attr("default_device"));
|
||||
} catch (const py::cast_error& e) {
|
||||
// Pathways and Cloud TPU 2VM runtime.
|
||||
@ -1217,16 +1228,16 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
||||
// (b) it does not set the device stickiness yet.
|
||||
// TODO(jblespiau): Finish the replacement of the Python feature.
|
||||
jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
|
||||
ClientAndPtr<PjRtDevice> to_device) {
|
||||
xla::ClientAndPtr<xla::PjRtDevice> to_device) {
|
||||
std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
|
||||
StatusOr<DevicePutResult> results =
|
||||
xla::StatusOr<DevicePutResult> results =
|
||||
DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
|
||||
if (!results.ok()) {
|
||||
throw std::runtime_error(results.status().error_message());
|
||||
}
|
||||
if (results->owned_buffer) {
|
||||
auto buffer = std::make_unique<PyBuffer>(
|
||||
pyclient, std::move(results->owned_buffer), Traceback::Get());
|
||||
auto buffer = std::make_unique<xla::PyBuffer>(
|
||||
pyclient, std::move(results->owned_buffer), xla::Traceback::Get());
|
||||
|
||||
static const auto* jax_core =
|
||||
new py::module(py::module::import("jax.core"));
|
||||
@ -1248,9 +1259,10 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
||||
[](const ArgSignature& sig) {
|
||||
return PrimitiveTypeToDtype(sig.dtype);
|
||||
})
|
||||
.def_property_readonly(
|
||||
"shape",
|
||||
[](const ArgSignature& sig) { return IntSpanToTuple(sig.shape); })
|
||||
.def_property_readonly("shape",
|
||||
[](const ArgSignature& sig) {
|
||||
return xla::IntSpanToTuple(sig.shape);
|
||||
})
|
||||
.def_readonly("weak_type", &ArgSignature::weak_type);
|
||||
jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue);
|
||||
|
||||
@ -1269,4 +1281,4 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
||||
jitlib.def("_is_trivial", &IsTrivialLazyExpr);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
} // namespace jax
|
||||
|
@ -24,15 +24,15 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/pytree.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace jax {
|
||||
|
||||
// Describes the abstract shape and dtype of an argument.
|
||||
struct ArgSignature {
|
||||
ArgSignature(PrimitiveType dtype, absl::Span<const int64> shape,
|
||||
ArgSignature(xla::PrimitiveType dtype, absl::Span<const int64> shape,
|
||||
bool weak_type)
|
||||
: dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
|
||||
// This is the XLA dtype of the object.
|
||||
const PrimitiveType dtype;
|
||||
const xla::PrimitiveType dtype;
|
||||
const absl::InlinedVector<int64, 4> shape;
|
||||
// JAX arguments can be of weak type, if and only if they are Python scalars
|
||||
// or `DeviceArray` values such that `aval.weak_type` is true.
|
||||
@ -66,7 +66,7 @@ struct CallSignature {
|
||||
// To avoid comparing strings, we intern the kwargs strings.
|
||||
// The compilation cache holds a reference to all the keys.
|
||||
pybind11::handle key;
|
||||
PyTreeDef value_treedef;
|
||||
xla::PyTreeDef value_treedef;
|
||||
bool operator==(const KwargEntry& other) const {
|
||||
return key.ptr() == other.key.ptr() &&
|
||||
value_treedef == other.value_treedef;
|
||||
@ -78,13 +78,13 @@ struct CallSignature {
|
||||
// order of their argnum index.
|
||||
std::vector<pybind11::object> static_args;
|
||||
// A PyTreeDef for each positional dynamic (i.e. not static) argument.
|
||||
std::vector<PyTreeDef> dynamic_positional_args_treedef;
|
||||
std::vector<xla::PyTreeDef> dynamic_positional_args_treedef;
|
||||
// Keyword arguments. Sorted by the keyword name.
|
||||
std::vector<KwargEntry> keyword_args;
|
||||
// Shape and dtype for both the dynamic positional arguments and the keyword
|
||||
// arguments (sorted by keyword name).
|
||||
std::vector<ArgSignature> dynamic_args_signatures;
|
||||
PjRtDevice* device;
|
||||
xla::PjRtDevice* device;
|
||||
|
||||
bool operator==(const CallSignature& other) const;
|
||||
bool operator!=(const CallSignature& other) const {
|
||||
@ -132,28 +132,28 @@ struct ParsedArgumentsAsBuffers {
|
||||
|
||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||
Status ParseArguments(const pybind11::args& args,
|
||||
const pybind11::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments);
|
||||
xla::Status ParseArguments(const pybind11::args& args,
|
||||
const pybind11::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments);
|
||||
|
||||
struct DevicePutResult {
|
||||
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
|
||||
explicit DevicePutResult(xla::PjRtBuffer* b, bool weak_type)
|
||||
: buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
|
||||
DevicePutResult(std::unique_ptr<PjRtBuffer> new_buffer, bool weak_type)
|
||||
DevicePutResult(std::unique_ptr<xla::PjRtBuffer> new_buffer, bool weak_type)
|
||||
: buffer(new_buffer.get()),
|
||||
weak_type(weak_type),
|
||||
owned_buffer(std::move(new_buffer)) {}
|
||||
|
||||
PjRtBuffer* buffer;
|
||||
xla::PjRtBuffer* buffer;
|
||||
bool weak_type;
|
||||
std::unique_ptr<PjRtBuffer> owned_buffer;
|
||||
std::unique_ptr<xla::PjRtBuffer> owned_buffer;
|
||||
};
|
||||
|
||||
// Returns the ArgSignature associated with an argument. Returns an error if
|
||||
// the argument is not supported.
|
||||
StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
bool jax_enable_x64);
|
||||
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
bool jax_enable_x64);
|
||||
|
||||
// Moves a device-like object to be on device.
|
||||
// - If the object is already on device, `owned_buffer` will be nullptr.
|
||||
@ -161,15 +161,17 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||
// `owned_buffer`.
|
||||
// In all cases, `buffer` will point to the already existing or newly created
|
||||
// buffer.
|
||||
// If `obj` is not convertible to a `PjRtBuffer` from C++, an error will be
|
||||
// If `obj` is not convertible to a `xla::PjRtBuffer` from C++, an error will be
|
||||
// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
|
||||
// supported yet.
|
||||
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
||||
bool jax_enable_x64, PyClient& pyclient);
|
||||
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
|
||||
xla::PjRtDevice* to_device,
|
||||
bool jax_enable_x64,
|
||||
xla::PyClient& pyclient);
|
||||
|
||||
// The function to call in `xla.cc` to add the bindings for this module.
|
||||
void BuildJaxjitSubmodule(pybind11::module& m);
|
||||
|
||||
} // namespace xla
|
||||
} // namespace jax
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
|
||||
|
@ -415,7 +415,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
BuildOpsSubmodule(&m);
|
||||
BuildOutfeedReceiverSubmodule(&m);
|
||||
BuildPytreeSubmodule(m);
|
||||
BuildJaxjitSubmodule(m);
|
||||
jax::BuildJaxjitSubmodule(m);
|
||||
jax::BuildPmapSubmodule(m);
|
||||
BuildTracebackSubmodule(m);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user