Move jax_jit functions to the jax
namespace.
PiperOrigin-RevId: 352208505 Change-Id: Iff1ac8e432f13d41f71b41b6e560c35765549043
This commit is contained in:
parent
3b94e9cfdb
commit
10c06df949
@ -53,7 +53,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace jax {
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
@ -180,11 +180,11 @@ H AbslHashValue(H h, const CallSignature& s) {
|
|||||||
|
|
||||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||||
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
xla::Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||||
absl::Span<int const> static_argnums,
|
absl::Span<int const> static_argnums,
|
||||||
ParsedArgumentsAsBuffers& arguments) {
|
ParsedArgumentsAsBuffers& arguments) {
|
||||||
if (static_argnums.size() > args.size()) {
|
if (static_argnums.size() > args.size()) {
|
||||||
return InvalidArgument(
|
return xla::InvalidArgument(
|
||||||
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
|
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
|
||||||
}
|
}
|
||||||
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
|
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
|
||||||
@ -196,7 +196,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
|||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
|
if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
|
||||||
static_argnums.end()) {
|
static_argnums.end()) {
|
||||||
PyTreeDef pytree_def;
|
xla::PyTreeDef pytree_def;
|
||||||
pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
|
pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
|
||||||
arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
|
arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
|
||||||
} else {
|
} else {
|
||||||
@ -210,7 +210,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
|||||||
std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
|
std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
|
||||||
py_kwargs.end());
|
py_kwargs.end());
|
||||||
// We first intern the keys, then sort them (by name, as in the Python path)
|
// We first intern the keys, then sort them (by name, as in the Python path)
|
||||||
// (see also PyTreeDef::Flatten) and then create the signatures.
|
// (see also xla::PyTreeDef::Flatten) and then create the signatures.
|
||||||
// TODO(jblespiau): We should be able to sort the keys by interned-key
|
// TODO(jblespiau): We should be able to sort the keys by interned-key
|
||||||
// pointers, but this requires the Python compilation to do the same.
|
// pointers, but this requires the Python compilation to do the same.
|
||||||
arguments.signature.keyword_args.resize(kwargs.size());
|
arguments.signature.keyword_args.resize(kwargs.size());
|
||||||
@ -236,7 +236,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
|||||||
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
|
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
|
||||||
kwargs[i].second, arguments.flat_dynamic_args);
|
kwargs[i].second, arguments.flat_dynamic_args);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return xla::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -374,11 +374,11 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ToArgSignatureHandler =
|
using ToArgSignatureHandler =
|
||||||
std::function<StatusOr<ArgSignature>(py::handle, bool)>;
|
std::function<xla::StatusOr<ArgSignature>(py::handle, bool)>;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||||
bool jax_enable_x64) {
|
bool jax_enable_x64) {
|
||||||
static const absl::flat_hash_map<PyObject*,
|
static const absl::flat_hash_map<PyObject*,
|
||||||
ToArgSignatureHandler>* const handlers = [] {
|
ToArgSignatureHandler>* const handlers = [] {
|
||||||
auto p = new absl::flat_hash_map<PyObject*, ToArgSignatureHandler>();
|
auto p = new absl::flat_hash_map<PyObject*, ToArgSignatureHandler>();
|
||||||
@ -389,38 +389,40 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
|
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
|
||||||
|
|
||||||
// The 4 Python native types.
|
// The 4 Python native types.
|
||||||
ToArgSignatureHandler bool_handler = [](py::handle,
|
ToArgSignatureHandler bool_handler =
|
||||||
bool) -> StatusOr<ArgSignature> {
|
[](py::handle, bool) -> xla::StatusOr<ArgSignature> {
|
||||||
return ArgSignature(PrimitiveType::PRED, {}, true);
|
return ArgSignature(xla::PrimitiveType::PRED, {}, true);
|
||||||
};
|
};
|
||||||
ToArgSignatureHandler int_handler =
|
ToArgSignatureHandler int_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return ArgSignature(PrimitiveType::S64, {}, true);
|
return ArgSignature(xla::PrimitiveType::S64, {}, true);
|
||||||
} else {
|
} else {
|
||||||
return ArgSignature(PrimitiveType::S32, {}, true);
|
return ArgSignature(xla::PrimitiveType::S32, {}, true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ToArgSignatureHandler float_handler =
|
ToArgSignatureHandler float_handler =
|
||||||
[&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[&dtypes](py::handle h,
|
||||||
|
bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
// Only Python native types has a True weak_type.
|
// Only Python native types has a True weak_type.
|
||||||
bool weak_type = !py::isinstance(h, dtypes.np_float64);
|
bool weak_type = !py::isinstance(h, dtypes.np_float64);
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return ArgSignature(PrimitiveType::F64, {}, weak_type);
|
return ArgSignature(xla::PrimitiveType::F64, {}, weak_type);
|
||||||
} else {
|
} else {
|
||||||
return ArgSignature(PrimitiveType::F32, {}, weak_type);
|
return ArgSignature(xla::PrimitiveType::F32, {}, weak_type);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ToArgSignatureHandler complex_handler =
|
ToArgSignatureHandler complex_handler =
|
||||||
[&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[&dtypes](py::handle h,
|
||||||
|
bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
// Note that this branch is also taken for np.complex128:
|
// Note that this branch is also taken for np.complex128:
|
||||||
// isinstance(np.complex128(3), complex) returns True
|
// isinstance(np.complex128(3), complex) returns True
|
||||||
// isinstance(np.complex64(3), complex) returns False
|
// isinstance(np.complex64(3), complex) returns False
|
||||||
bool weak_type = !py::isinstance(h, dtypes.np_complex128);
|
bool weak_type = !py::isinstance(h, dtypes.np_complex128);
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return ArgSignature(PrimitiveType::C128, {}, weak_type);
|
return ArgSignature(xla::PrimitiveType::C128, {}, weak_type);
|
||||||
} else {
|
} else {
|
||||||
return ArgSignature(PrimitiveType::C64, {}, weak_type);
|
return ArgSignature(xla::PrimitiveType::C64, {}, weak_type);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -432,18 +434,19 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
// The Buffer types
|
// The Buffer types
|
||||||
// PyBuffer necessarily has a trivial LazyExpr, no need to check it.
|
// PyBuffer necessarily has a trivial LazyExpr, no need to check it.
|
||||||
ToArgSignatureHandler buffer_handler =
|
ToArgSignatureHandler buffer_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
|
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
|
||||||
bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
|
bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
|
||||||
return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
|
return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
|
||||||
buffer->buffer()->on_host_shape().dimensions(),
|
buffer->buffer()->on_host_shape().dimensions(),
|
||||||
weak_type);
|
weak_type);
|
||||||
};
|
};
|
||||||
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = buffer_handler;
|
(*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
|
||||||
ToArgSignatureHandler device_array_handler =
|
ToArgSignatureHandler device_array_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
py::handle aval = h.attr("aval");
|
py::handle aval = h.attr("aval");
|
||||||
TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(aval.attr("dtype")));
|
TF_ASSIGN_OR_RETURN(auto dtype,
|
||||||
|
xla::DtypeToPrimitiveType(aval.attr("dtype")));
|
||||||
return ArgSignature(dtype,
|
return ArgSignature(dtype,
|
||||||
py::cast<std::vector<int64>>(aval.attr("shape")),
|
py::cast<std::vector<int64>>(aval.attr("shape")),
|
||||||
py::cast<py::bool_>(aval.attr("weak_type")));
|
py::cast<py::bool_>(aval.attr("weak_type")));
|
||||||
@ -452,10 +455,10 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
(*p)[device_array.ptr()] = device_array_handler;
|
(*p)[device_array.ptr()] = device_array_handler;
|
||||||
|
|
||||||
ToArgSignatureHandler numpy_handler =
|
ToArgSignatureHandler numpy_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
py::array numpy_array = py::cast<py::array>(h);
|
py::array numpy_array = py::cast<py::array>(h);
|
||||||
if (IsFloat0(numpy_array)) {
|
if (IsFloat0(numpy_array)) {
|
||||||
return InvalidArgument(
|
return xla::InvalidArgument(
|
||||||
"float0 numpy arrays not supported in C++. "
|
"float0 numpy arrays not supported in C++. "
|
||||||
"Falling back to Python.");
|
"Falling back to Python.");
|
||||||
}
|
}
|
||||||
@ -463,11 +466,11 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
const py::dtype raw_dtype = numpy_array.dtype();
|
const py::dtype raw_dtype = numpy_array.dtype();
|
||||||
const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype);
|
const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype);
|
||||||
|
|
||||||
PrimitiveType dtype;
|
xla::PrimitiveType dtype;
|
||||||
if (to_dtype) {
|
if (to_dtype) {
|
||||||
TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(*to_dtype));
|
TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(*to_dtype));
|
||||||
} else {
|
} else {
|
||||||
TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(raw_dtype));
|
TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(raw_dtype));
|
||||||
}
|
}
|
||||||
// We need the reinterpret_cast for the OSS version to build.
|
// We need the reinterpret_cast for the OSS version to build.
|
||||||
return ArgSignature(dtype,
|
return ArgSignature(dtype,
|
||||||
@ -477,7 +480,7 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
/*weak_type=*/false);
|
/*weak_type=*/false);
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(auto dtype,
|
TF_ASSIGN_OR_RETURN(auto dtype,
|
||||||
DtypeToPrimitiveType(numpy_array.dtype()));
|
xla::DtypeToPrimitiveType(numpy_array.dtype()));
|
||||||
return ArgSignature(dtype,
|
return ArgSignature(dtype,
|
||||||
absl::MakeConstSpan(reinterpret_cast<const int64*>(
|
absl::MakeConstSpan(reinterpret_cast<const int64*>(
|
||||||
numpy_array.shape()),
|
numpy_array.shape()),
|
||||||
@ -489,27 +492,28 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
(*p)[ndarray.ptr()] = numpy_handler;
|
(*p)[ndarray.ptr()] = numpy_handler;
|
||||||
|
|
||||||
ToArgSignatureHandler np_uint64_handler =
|
ToArgSignatureHandler np_uint64_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return ArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
|
return ArgSignature(xla::PrimitiveType::U64, {}, /*weak_type=*/false);
|
||||||
} else {
|
} else {
|
||||||
return ArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
|
return ArgSignature(xla::PrimitiveType::U32, {}, /*weak_type=*/false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ToArgSignatureHandler np_int_handler =
|
ToArgSignatureHandler np_int_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return ArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
|
return ArgSignature(xla::PrimitiveType::S64, {}, /*weak_type=*/false);
|
||||||
} else {
|
} else {
|
||||||
return ArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
|
return ArgSignature(xla::PrimitiveType::S32, {}, /*weak_type=*/false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ToArgSignatureHandler numpy_array_handler =
|
ToArgSignatureHandler numpy_array_handler =
|
||||||
[](py::handle h, bool jax_enable_x64) -> StatusOr<ArgSignature> {
|
[](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
|
||||||
// This block deals with all numpy scalar types, except for int64_dt,
|
// This block deals with all numpy scalar types, except for int64_dt,
|
||||||
// float64_dt and complex128_dt which are taken care of in previous if
|
// float64_dt and complex128_dt which are taken care of in previous if
|
||||||
// blocks.
|
// blocks.
|
||||||
TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(h.attr("dtype")));
|
TF_ASSIGN_OR_RETURN(auto dtype,
|
||||||
|
xla::DtypeToPrimitiveType(h.attr("dtype")));
|
||||||
return ArgSignature(dtype, {}, /*weak_type=*/false);
|
return ArgSignature(dtype, {}, /*weak_type=*/false);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -545,7 +549,7 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
return res->second(arg, jax_enable_x64);
|
return res->second(arg, jax_enable_x64);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return InvalidArgument(
|
return xla::InvalidArgument(
|
||||||
"%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts "
|
"%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts "
|
||||||
"Buffer/DeviceArray/ShardedDeviceArray, Numpy "
|
"Buffer/DeviceArray/ShardedDeviceArray, Numpy "
|
||||||
"arrays scalars of supported types "
|
"arrays scalars of supported types "
|
||||||
@ -557,17 +561,17 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
|
using DevicePutFunc = std::function<xla::StatusOr<DevicePutResult>(
|
||||||
py::handle, PjRtDevice*, bool, xla::PyClient&)>;
|
py::handle, xla::PjRtDevice*, bool, xla::PyClient&)>;
|
||||||
|
|
||||||
DevicePutResult HandleBool(py::handle h, PjRtDevice* to_device,
|
DevicePutResult HandleBool(py::handle h, xla::PjRtDevice* to_device,
|
||||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||||
return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
|
return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
|
||||||
h, pyclient.pjrt_client(), to_device),
|
h, pyclient.pjrt_client(), to_device),
|
||||||
/*weak_type=*/true);
|
/*weak_type=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
|
DevicePutResult HandleInt(py::handle obj, xla::PjRtDevice* to_device,
|
||||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
|
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
|
||||||
@ -581,9 +585,10 @@ DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool weak_type>
|
template <bool weak_type>
|
||||||
StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
|
xla::StatusOr<DevicePutResult> HandleFloat(py::handle h,
|
||||||
bool jax_enable_x64,
|
xla::PjRtDevice* to_device,
|
||||||
xla::PyClient& pyclient) {
|
bool jax_enable_x64,
|
||||||
|
xla::PyClient& pyclient) {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
|
return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
|
||||||
h, pyclient.pjrt_client(), to_device),
|
h, pyclient.pjrt_client(), to_device),
|
||||||
@ -596,9 +601,10 @@ StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool weak_type>
|
template <bool weak_type>
|
||||||
StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
|
xla::StatusOr<DevicePutResult> HandleComplex(py::handle h,
|
||||||
bool jax_enable_x64,
|
xla::PjRtDevice* to_device,
|
||||||
xla::PyClient& pyclient) {
|
bool jax_enable_x64,
|
||||||
|
xla::PyClient& pyclient) {
|
||||||
// This branch is also taken for np.complex128:
|
// This branch is also taken for np.complex128:
|
||||||
// isinstance(np.complex128(3), complex) returns True
|
// isinstance(np.complex128(3), complex) returns True
|
||||||
// isinstance(np.complex64(3), complex) returns False
|
// isinstance(np.complex64(3), complex) returns False
|
||||||
@ -628,22 +634,22 @@ StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
|
xla::StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
|
||||||
PjRtDevice* to_device,
|
xla::PjRtDevice* to_device,
|
||||||
bool jax_enable_x64,
|
bool jax_enable_x64,
|
||||||
xla::PyClient& pyclient) {
|
xla::PyClient& pyclient) {
|
||||||
if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
|
if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
|
||||||
return InvalidArgument(
|
return xla::InvalidArgument(
|
||||||
"Non-trivial lazy expression not supported in C++. "
|
"Non-trivial lazy expression not supported in C++. "
|
||||||
"Falling back to Python.");
|
"Falling back to Python.");
|
||||||
}
|
}
|
||||||
PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
|
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
|
||||||
bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
|
bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
|
||||||
// Same block as in the previous `if (is_py_buffer)`.
|
// Same block as in the previous `if (is_py_buffer)`.
|
||||||
if (buffer->device().contents == to_device) {
|
if (buffer->device().contents == to_device) {
|
||||||
return DevicePutResult(buffer->buffer(), weak_type);
|
return DevicePutResult(buffer->buffer(), weak_type);
|
||||||
} else {
|
} else {
|
||||||
std::unique_ptr<PjRtBuffer> copied_buffer =
|
std::unique_ptr<xla::PjRtBuffer> copied_buffer =
|
||||||
ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
|
ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
|
||||||
return DevicePutResult(std::move(copied_buffer), weak_type);
|
return DevicePutResult(std::move(copied_buffer), weak_type);
|
||||||
}
|
}
|
||||||
@ -651,7 +657,7 @@ StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
|
|||||||
|
|
||||||
// Do not convert types, and only call PjRtBufferFromPyval, independently
|
// Do not convert types, and only call PjRtBufferFromPyval, independently
|
||||||
// of the value of jax_enable_x64.
|
// of the value of jax_enable_x64.
|
||||||
DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
|
DevicePutResult HandleBufferFromPyval(py::handle h, xla::PjRtDevice* to_device,
|
||||||
bool jax_enable_x64,
|
bool jax_enable_x64,
|
||||||
xla::PyClient& pyclient) {
|
xla::PyClient& pyclient) {
|
||||||
std::unique_ptr<xla::PjRtBuffer> buffer =
|
std::unique_ptr<xla::PjRtBuffer> buffer =
|
||||||
@ -662,7 +668,7 @@ DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
|
|||||||
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
|
return DevicePutResult(std::move(buffer), /*weak_type=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
|
DevicePutResult HandleNpBool(py::handle h, xla::PjRtDevice* to_device,
|
||||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
|
return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
|
||||||
@ -675,7 +681,7 @@ DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
|
DevicePutResult HandleUint64(py::handle h, xla::PjRtDevice* to_device,
|
||||||
bool jax_enable_x64, xla::PyClient& pyclient) {
|
bool jax_enable_x64, xla::PyClient& pyclient) {
|
||||||
if (jax_enable_x64) {
|
if (jax_enable_x64) {
|
||||||
std::unique_ptr<xla::PjRtBuffer> buffer =
|
std::unique_ptr<xla::PjRtBuffer> buffer =
|
||||||
@ -698,14 +704,15 @@ DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
|
xla::StatusOr<DevicePutResult> HandleNdarray(py::handle h,
|
||||||
bool jax_enable_x64,
|
xla::PjRtDevice* to_device,
|
||||||
xla::PyClient& pyclient) {
|
bool jax_enable_x64,
|
||||||
|
xla::PyClient& pyclient) {
|
||||||
py::array numpy_array = py::cast<py::array>(h);
|
py::array numpy_array = py::cast<py::array>(h);
|
||||||
if (IsFloat0(numpy_array)) {
|
if (IsFloat0(numpy_array)) {
|
||||||
return InvalidArgument("%s",
|
return xla::InvalidArgument("%s",
|
||||||
"float0 numpy arrays not supported in C++. "
|
"float0 numpy arrays not supported in C++. "
|
||||||
"Falling back to Python.");
|
"Falling back to Python.");
|
||||||
}
|
}
|
||||||
// If jax_enable_x64 is not set, we need to coerce 32 bits types.
|
// If jax_enable_x64 is not set, we need to coerce 32 bits types.
|
||||||
// Note that this is calling back to Python!
|
// Note that this is calling back to Python!
|
||||||
@ -727,9 +734,10 @@ StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
|
||||||
bool jax_enable_x64,
|
xla::PjRtDevice* to_device,
|
||||||
xla::PyClient& pyclient) {
|
bool jax_enable_x64,
|
||||||
|
xla::PyClient& pyclient) {
|
||||||
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
|
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
|
||||||
[] {
|
[] {
|
||||||
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
|
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
|
||||||
@ -751,7 +759,8 @@ StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
|||||||
const auto pxla_module = py::module::import("jax.interpreters.pxla");
|
const auto pxla_module = py::module::import("jax.interpreters.pxla");
|
||||||
const auto& sda = pxla_module.attr("ShardedDeviceArray");
|
const auto& sda = pxla_module.attr("ShardedDeviceArray");
|
||||||
(*p)[device_array.ptr()] = HandleDeviceArray;
|
(*p)[device_array.ptr()] = HandleDeviceArray;
|
||||||
(*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = HandleDeviceArray;
|
(*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] =
|
||||||
|
HandleDeviceArray;
|
||||||
(*p)[sda.ptr()] = HandleBufferFromPyval;
|
(*p)[sda.ptr()] = HandleBufferFromPyval;
|
||||||
// Numpy arrays.
|
// Numpy arrays.
|
||||||
(*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;
|
(*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;
|
||||||
@ -786,7 +795,7 @@ StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
|||||||
return res->second(arg, to_device, jax_enable_x64, pyclient);
|
return res->second(arg, to_device, jax_enable_x64, pyclient);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return InvalidArgument(
|
return xla::InvalidArgument(
|
||||||
"%s", absl::StrCat(
|
"%s", absl::StrCat(
|
||||||
"Not supported: The C++ jax jit execution path, only accepts "
|
"Not supported: The C++ jax jit execution path, only accepts "
|
||||||
"DeviceArray, Numpy arrays scalars of supported types "
|
"DeviceArray, Numpy arrays scalars of supported types "
|
||||||
@ -801,7 +810,7 @@ namespace {
|
|||||||
|
|
||||||
struct CacheEntry {
|
struct CacheEntry {
|
||||||
std::shared_ptr<xla::PyExecutable> executable;
|
std::shared_ptr<xla::PyExecutable> executable;
|
||||||
PyTreeDef out_pytree_def;
|
xla::PyTreeDef out_pytree_def;
|
||||||
// We use Python types within the vector because this is what we will be
|
// We use Python types within the vector because this is what we will be
|
||||||
// returning to Python. No need to convert back and forth.
|
// returning to Python. No need to convert back and forth.
|
||||||
// We need py::object to maintain the objects alive.
|
// We need py::object to maintain the objects alive.
|
||||||
@ -817,7 +826,7 @@ struct CacheEntry {
|
|||||||
// a signature and if the object has been insterted already, other threads
|
// a signature and if the object has been insterted already, other threads
|
||||||
// will wait for the notification.
|
// will wait for the notification.
|
||||||
absl::Notification compilation_complete;
|
absl::Notification compilation_complete;
|
||||||
absl::optional<Status> compilation_error = absl::nullopt;
|
absl::optional<xla::Status> compilation_error = absl::nullopt;
|
||||||
// Trivial computation will fallback to Python.
|
// Trivial computation will fallback to Python.
|
||||||
// Running a jax(pmap) will also fallback to Python.
|
// Running a jax(pmap) will also fallback to Python.
|
||||||
bool fall_back_to_python = false;
|
bool fall_back_to_python = false;
|
||||||
@ -896,7 +905,7 @@ class CompiledFunction {
|
|||||||
// the `default_device_` which will be used as the targeted device. In
|
// the `default_device_` which will be used as the targeted device. In
|
||||||
// which case, we will always copy input buffers to this device.
|
// which case, we will always copy input buffers to this device.
|
||||||
std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
|
std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
|
||||||
xla::ClientAndPtr<PjRtDevice> default_pydevice_;
|
xla::ClientAndPtr<xla::PjRtDevice> default_pydevice_;
|
||||||
xla::PjRtDevice* default_device_ = nullptr;
|
xla::PjRtDevice* default_device_ = nullptr;
|
||||||
bool is_committed_;
|
bool is_committed_;
|
||||||
};
|
};
|
||||||
@ -924,11 +933,12 @@ CompiledFunction::~CompiledFunction() {
|
|||||||
// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
|
// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
|
||||||
// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
|
// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
|
||||||
//
|
//
|
||||||
// Returns `OkStatus()` on success. Returning an error should lead to calling
|
// Returns `Okxla::Status()` on success. Returning an error should lead to
|
||||||
// the Python fallback.
|
// calling the Python fallback.
|
||||||
Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||||
xla::PjRtDevice* default_device, bool is_committed,
|
xla::PjRtDevice* default_device,
|
||||||
ParsedArgumentsAsBuffers& arguments) {
|
bool is_committed,
|
||||||
|
ParsedArgumentsAsBuffers& arguments) {
|
||||||
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
|
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
|
||||||
auto& keep_alive = arguments.keep_alive;
|
auto& keep_alive = arguments.keep_alive;
|
||||||
|
|
||||||
@ -953,7 +963,8 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
|||||||
for (py::handle arg : arguments.flat_dynamic_args) {
|
for (py::handle arg : arguments.flat_dynamic_args) {
|
||||||
// We specically only deal with DeviceArray (not ShardedDeviceArray).
|
// We specically only deal with DeviceArray (not ShardedDeviceArray).
|
||||||
// (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
|
// (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
|
||||||
if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
|
if (py::isinstance<xla::PyBuffer>(arg) ||
|
||||||
|
arg.get_type().is(device_array)) {
|
||||||
xla::PyBuffer* buffer;
|
xla::PyBuffer* buffer;
|
||||||
if (arg.attr("_device").is_none()) { // Skip non-sticky devices.
|
if (arg.attr("_device").is_none()) { // Skip non-sticky devices.
|
||||||
continue;
|
continue;
|
||||||
@ -962,7 +973,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
|||||||
// This can fail, e.g. when device_buffer is a `DeviceConstant`.
|
// This can fail, e.g. when device_buffer is a `DeviceConstant`.
|
||||||
buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
|
buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
|
||||||
} catch (const py::cast_error& e) {
|
} catch (const py::cast_error& e) {
|
||||||
return InvalidArgument(
|
return xla::InvalidArgument(
|
||||||
"%s",
|
"%s",
|
||||||
absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
|
absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
|
||||||
"`device_buffer` field is of type ",
|
"`device_buffer` field is of type ",
|
||||||
@ -995,7 +1006,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
|||||||
TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
|
TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
|
||||||
DevicePut(arg, data_device, jax_enable_x64, pyclient));
|
DevicePut(arg, data_device, jax_enable_x64, pyclient));
|
||||||
|
|
||||||
PjRtBuffer* buffer = on_device.buffer;
|
xla::PjRtBuffer* buffer = on_device.buffer;
|
||||||
arg_buffers.push_back(buffer);
|
arg_buffers.push_back(buffer);
|
||||||
if (on_device.owned_buffer) {
|
if (on_device.owned_buffer) {
|
||||||
keep_alive.emplace_back(std::move(on_device.owned_buffer));
|
keep_alive.emplace_back(std::move(on_device.owned_buffer));
|
||||||
@ -1005,7 +1016,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
|||||||
buffer->on_host_shape().dimensions(), on_device.weak_type);
|
buffer->on_host_shape().dimensions(), on_device.weak_type);
|
||||||
arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
|
arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return xla::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
|
CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
|
||||||
@ -1063,7 +1074,7 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
|
|||||||
// The presence of jit(pmap) is detected from Python.
|
// The presence of jit(pmap) is detected from Python.
|
||||||
CHECK_EQ(num_devices, 1);
|
CHECK_EQ(num_devices, 1);
|
||||||
|
|
||||||
auto out_tree = py::cast<PyTreeDef>(executable_handlers_out_tree[1]);
|
auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
|
||||||
cache_entry->out_pytree_def = std::move(out_tree);
|
cache_entry->out_pytree_def = std::move(out_tree);
|
||||||
|
|
||||||
cache_entry->sticky_device =
|
cache_entry->sticky_device =
|
||||||
@ -1105,7 +1116,7 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
|||||||
if (!default_device_) {
|
if (!default_device_) {
|
||||||
py::object device_and_is_committed = get_device_();
|
py::object device_and_is_committed = get_device_();
|
||||||
try {
|
try {
|
||||||
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
|
default_pydevice_ = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
|
||||||
device_and_is_committed.attr("default_device"));
|
device_and_is_committed.attr("default_device"));
|
||||||
} catch (const py::cast_error& e) {
|
} catch (const py::cast_error& e) {
|
||||||
// Pathways and Cloud TPU 2VM runtime.
|
// Pathways and Cloud TPU 2VM runtime.
|
||||||
@ -1217,16 +1228,16 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
|||||||
// (b) it does not set the device stickiness yet.
|
// (b) it does not set the device stickiness yet.
|
||||||
// TODO(jblespiau): Finish the replacement of the Python feature.
|
// TODO(jblespiau): Finish the replacement of the Python feature.
|
||||||
jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
|
jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
|
||||||
ClientAndPtr<PjRtDevice> to_device) {
|
xla::ClientAndPtr<xla::PjRtDevice> to_device) {
|
||||||
std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
|
std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
|
||||||
StatusOr<DevicePutResult> results =
|
xla::StatusOr<DevicePutResult> results =
|
||||||
DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
|
DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
|
||||||
if (!results.ok()) {
|
if (!results.ok()) {
|
||||||
throw std::runtime_error(results.status().error_message());
|
throw std::runtime_error(results.status().error_message());
|
||||||
}
|
}
|
||||||
if (results->owned_buffer) {
|
if (results->owned_buffer) {
|
||||||
auto buffer = std::make_unique<PyBuffer>(
|
auto buffer = std::make_unique<xla::PyBuffer>(
|
||||||
pyclient, std::move(results->owned_buffer), Traceback::Get());
|
pyclient, std::move(results->owned_buffer), xla::Traceback::Get());
|
||||||
|
|
||||||
static const auto* jax_core =
|
static const auto* jax_core =
|
||||||
new py::module(py::module::import("jax.core"));
|
new py::module(py::module::import("jax.core"));
|
||||||
@ -1248,9 +1259,10 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
|||||||
[](const ArgSignature& sig) {
|
[](const ArgSignature& sig) {
|
||||||
return PrimitiveTypeToDtype(sig.dtype);
|
return PrimitiveTypeToDtype(sig.dtype);
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly("shape",
|
||||||
"shape",
|
[](const ArgSignature& sig) {
|
||||||
[](const ArgSignature& sig) { return IntSpanToTuple(sig.shape); })
|
return xla::IntSpanToTuple(sig.shape);
|
||||||
|
})
|
||||||
.def_readonly("weak_type", &ArgSignature::weak_type);
|
.def_readonly("weak_type", &ArgSignature::weak_type);
|
||||||
jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue);
|
jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue);
|
||||||
|
|
||||||
@ -1269,4 +1281,4 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
|||||||
jitlib.def("_is_trivial", &IsTrivialLazyExpr);
|
jitlib.def("_is_trivial", &IsTrivialLazyExpr);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace jax
|
||||||
|
@ -24,15 +24,15 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/python/pytree.h"
|
#include "tensorflow/compiler/xla/python/pytree.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace jax {
|
||||||
|
|
||||||
// Describes the abstract shape and dtype of an argument.
|
// Describes the abstract shape and dtype of an argument.
|
||||||
struct ArgSignature {
|
struct ArgSignature {
|
||||||
ArgSignature(PrimitiveType dtype, absl::Span<const int64> shape,
|
ArgSignature(xla::PrimitiveType dtype, absl::Span<const int64> shape,
|
||||||
bool weak_type)
|
bool weak_type)
|
||||||
: dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
|
: dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
|
||||||
// This is the XLA dtype of the object.
|
// This is the XLA dtype of the object.
|
||||||
const PrimitiveType dtype;
|
const xla::PrimitiveType dtype;
|
||||||
const absl::InlinedVector<int64, 4> shape;
|
const absl::InlinedVector<int64, 4> shape;
|
||||||
// JAX arguments can be of weak type, if and only if they are Python scalars
|
// JAX arguments can be of weak type, if and only if they are Python scalars
|
||||||
// or `DeviceArray` values such that `aval.weak_type` is true.
|
// or `DeviceArray` values such that `aval.weak_type` is true.
|
||||||
@ -66,7 +66,7 @@ struct CallSignature {
|
|||||||
// To avoid comparing strings, we intern the kwargs strings.
|
// To avoid comparing strings, we intern the kwargs strings.
|
||||||
// The compilation cache holds a reference to all the keys.
|
// The compilation cache holds a reference to all the keys.
|
||||||
pybind11::handle key;
|
pybind11::handle key;
|
||||||
PyTreeDef value_treedef;
|
xla::PyTreeDef value_treedef;
|
||||||
bool operator==(const KwargEntry& other) const {
|
bool operator==(const KwargEntry& other) const {
|
||||||
return key.ptr() == other.key.ptr() &&
|
return key.ptr() == other.key.ptr() &&
|
||||||
value_treedef == other.value_treedef;
|
value_treedef == other.value_treedef;
|
||||||
@ -78,13 +78,13 @@ struct CallSignature {
|
|||||||
// order of their argnum index.
|
// order of their argnum index.
|
||||||
std::vector<pybind11::object> static_args;
|
std::vector<pybind11::object> static_args;
|
||||||
// A PyTreeDef for each positional dynamic (i.e. not static) argument.
|
// A PyTreeDef for each positional dynamic (i.e. not static) argument.
|
||||||
std::vector<PyTreeDef> dynamic_positional_args_treedef;
|
std::vector<xla::PyTreeDef> dynamic_positional_args_treedef;
|
||||||
// Keyword arguments. Sorted by the keyword name.
|
// Keyword arguments. Sorted by the keyword name.
|
||||||
std::vector<KwargEntry> keyword_args;
|
std::vector<KwargEntry> keyword_args;
|
||||||
// Shape and dtype for both the dynamic positional arguments and the keyword
|
// Shape and dtype for both the dynamic positional arguments and the keyword
|
||||||
// arguments (sorted by keyword name).
|
// arguments (sorted by keyword name).
|
||||||
std::vector<ArgSignature> dynamic_args_signatures;
|
std::vector<ArgSignature> dynamic_args_signatures;
|
||||||
PjRtDevice* device;
|
xla::PjRtDevice* device;
|
||||||
|
|
||||||
bool operator==(const CallSignature& other) const;
|
bool operator==(const CallSignature& other) const;
|
||||||
bool operator!=(const CallSignature& other) const {
|
bool operator!=(const CallSignature& other) const {
|
||||||
@ -132,28 +132,28 @@ struct ParsedArgumentsAsBuffers {
|
|||||||
|
|
||||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||||
Status ParseArguments(const pybind11::args& args,
|
xla::Status ParseArguments(const pybind11::args& args,
|
||||||
const pybind11::kwargs& py_kwargs,
|
const pybind11::kwargs& py_kwargs,
|
||||||
absl::Span<int const> static_argnums,
|
absl::Span<int const> static_argnums,
|
||||||
ParsedArgumentsAsBuffers& arguments);
|
ParsedArgumentsAsBuffers& arguments);
|
||||||
|
|
||||||
struct DevicePutResult {
|
struct DevicePutResult {
|
||||||
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
|
explicit DevicePutResult(xla::PjRtBuffer* b, bool weak_type)
|
||||||
: buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
|
: buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
|
||||||
DevicePutResult(std::unique_ptr<PjRtBuffer> new_buffer, bool weak_type)
|
DevicePutResult(std::unique_ptr<xla::PjRtBuffer> new_buffer, bool weak_type)
|
||||||
: buffer(new_buffer.get()),
|
: buffer(new_buffer.get()),
|
||||||
weak_type(weak_type),
|
weak_type(weak_type),
|
||||||
owned_buffer(std::move(new_buffer)) {}
|
owned_buffer(std::move(new_buffer)) {}
|
||||||
|
|
||||||
PjRtBuffer* buffer;
|
xla::PjRtBuffer* buffer;
|
||||||
bool weak_type;
|
bool weak_type;
|
||||||
std::unique_ptr<PjRtBuffer> owned_buffer;
|
std::unique_ptr<xla::PjRtBuffer> owned_buffer;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns the ArgSignature associated with an argument. Returns an error if
|
// Returns the ArgSignature associated with an argument. Returns an error if
|
||||||
// the argument is not supported.
|
// the argument is not supported.
|
||||||
StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
||||||
bool jax_enable_x64);
|
bool jax_enable_x64);
|
||||||
|
|
||||||
// Moves a device-like object to be on device.
|
// Moves a device-like object to be on device.
|
||||||
// - If the object is already on device, `owned_buffer` will be nullptr.
|
// - If the object is already on device, `owned_buffer` will be nullptr.
|
||||||
@ -161,15 +161,17 @@ StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
|
|||||||
// `owned_buffer`.
|
// `owned_buffer`.
|
||||||
// In all cases, `buffer` will point to the already existing or newly created
|
// In all cases, `buffer` will point to the already existing or newly created
|
||||||
// buffer.
|
// buffer.
|
||||||
// If `obj` is not convertible to a `PjRtBuffer` from C++, an error will be
|
// If `obj` is not convertible to a `xla::PjRtBuffer` from C++, an error will be
|
||||||
// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
|
// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
|
||||||
// supported yet.
|
// supported yet.
|
||||||
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
|
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
|
||||||
bool jax_enable_x64, PyClient& pyclient);
|
xla::PjRtDevice* to_device,
|
||||||
|
bool jax_enable_x64,
|
||||||
|
xla::PyClient& pyclient);
|
||||||
|
|
||||||
// The function to call in `xla.cc` to add the bindings for this module.
|
// The function to call in `xla.cc` to add the bindings for this module.
|
||||||
void BuildJaxjitSubmodule(pybind11::module& m);
|
void BuildJaxjitSubmodule(pybind11::module& m);
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace jax
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
|
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
|
||||||
|
@ -415,7 +415,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
BuildOpsSubmodule(&m);
|
BuildOpsSubmodule(&m);
|
||||||
BuildOutfeedReceiverSubmodule(&m);
|
BuildOutfeedReceiverSubmodule(&m);
|
||||||
BuildPytreeSubmodule(m);
|
BuildPytreeSubmodule(m);
|
||||||
BuildJaxjitSubmodule(m);
|
jax::BuildJaxjitSubmodule(m);
|
||||||
jax::BuildPmapSubmodule(m);
|
jax::BuildPmapSubmodule(m);
|
||||||
BuildTracebackSubmodule(m);
|
BuildTracebackSubmodule(m);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user