Add support for more numpy types, and use a map-lookup design.
PiperOrigin-RevId: 351424260
Change-Id: I7d6e81574e0e3583ab4dbf59fdde75cfb5c951d2
parent 544d1eb908
commit 5fff2a4bca
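
For context on the "map-lookup design" named in the commit message: DevicePut now keys a hash map on the Python type of the argument, maps each supported type to a handler, and falls back to walking the type's MRO before reporting an unsupported type. The standalone sketch below is illustrative only and not part of the diff; it uses std::unordered_map and string keys as stand-ins for the PyObject* type pointers, absl::flat_hash_map, and StatusOr-returning handlers of the actual change.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in: the real table maps PyObject* type pointers to
// handlers returning StatusOr<DevicePutResult>.
using Handler = std::function<std::string(const std::string& value)>;

int main() {
  // One handler per supported type, built once (the real code builds a
  // function-local static absl::flat_hash_map inside a lambda).
  static const std::unordered_map<std::string, Handler> handlers = {
      {"bool", [](const std::string& v) { return "HandleBool(" + v + ")"; }},
      {"int", [](const std::string& v) { return "HandleInt(" + v + ")"; }},
      {"float", [](const std::string& v) { return "HandleFloat(" + v + ")"; }},
  };

  // Dispatch: exact type first, then walk the "MRO" (base classes) as a
  // fallback, mirroring handlers->find(arg.get_type().ptr()) in the diff.
  auto dispatch = [&](const std::string& type,
                      const std::vector<std::string>& mro,
                      const std::string& value) -> std::string {
    auto it = handlers.find(type);
    if (it == handlers.end()) {
      for (const auto& base : mro) {
        it = handlers.find(base);
        if (it != handlers.end()) return it->second(value);
      }
      return "error: unsupported type " + type;
    }
    return it->second(value);
  };

  std::cout << dispatch("int", {}, "3") << "\n";         // exact hit
  std::cout << dispatch("my_int", {"int"}, "3") << "\n"; // MRO fallback
  std::cout << dispatch("str", {}, "hi") << "\n";        // unsupported
  return 0;
}
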
@@ -266,6 +266,7 @@ cc_library(
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/pjrt:pjrt_client",
        "//tensorflow/core/platform:status",
        "//third_party/python_runtime:headers",  # build_cleaner: keep
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
@@ -26,6 +26,8 @@ limitations under the License.

#include "tensorflow/compiler/xla/python/jax_jit.h"

#include <Python.h>

#include <exception>
#include <memory>
#include <stdexcept>
@@ -238,7 +240,56 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
}

namespace {

struct NumpyScalarTypes {
  py::object np_bool;
  py::object np_int8;
  py::object np_int16;
  py::object np_int32;
  py::object np_int64;
  py::object np_uint8;
  py::object np_uint16;
  py::object np_uint32;
  py::object np_uint64;
  py::object np_float16;
  py::object np_float32;
  py::object np_float64;
  py::object np_complex64;
  py::object np_complex128;
  py::object np_longlong;
  py::object np_intc;
};

const NumpyScalarTypes& GetNumpyScalarTypes() {
  static const NumpyScalarTypes* singleton = []() {
    // Use Designated initializers when they are available.
    const auto numpy = py::module::import("numpy");
    NumpyScalarTypes* dtypes = new NumpyScalarTypes();
    dtypes->np_bool = py::object(numpy.attr("bool_"));
    dtypes->np_int8 = py::object(numpy.attr("int8"));
    dtypes->np_int16 = py::object(numpy.attr("int16"));
    dtypes->np_int32 = py::object(numpy.attr("int32"));
    dtypes->np_int64 = py::object(numpy.attr("int64"));
    dtypes->np_uint8 = py::object(numpy.attr("uint8"));
    dtypes->np_uint16 = py::object(numpy.attr("uint16"));
    dtypes->np_uint32 = py::object(numpy.attr("uint32"));
    dtypes->np_uint64 = py::object(numpy.attr("uint64"));
    dtypes->np_float16 = py::object(numpy.attr("float16"));
    dtypes->np_float32 = py::object(numpy.attr("float32"));
    dtypes->np_float64 = py::object(numpy.attr("float64"));
    dtypes->np_complex64 = py::object(numpy.attr("complex64"));
    dtypes->np_complex128 = py::object(numpy.attr("complex128"));
    dtypes->np_longlong = py::object(numpy.attr("longlong"));
    dtypes->np_intc = py::object(numpy.attr("intc"));

    return dtypes;
  }();

  return *singleton;
}

const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
  // TODO(jblespiau): Use GetNumpyScalarTypes instead.
  static const auto* int64_dt = new py::dtype("int64");
  static const auto* int32_dt = new py::dtype("int32");
  static const auto* uint64_dt = new py::dtype("uint64");
@@ -318,127 +369,246 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
      device));
}

// Convert a scalar to the associated PjRtBuffer or raises an error if it is
// not convertible (thus, this must be called after other checks).
StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
    py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client,
    xla::PjRtDevice* device) {
  // Important: In Python, isinstance(True, int) returns True. Thus, we have
  // to check for bool before int.
  if (py::isinstance<py::bool_>(scalar)) {
    return ConvertToScalarBuffer<bool, py::bool_>(scalar, client, device);
  } else if (py::isinstance<py::int_>(scalar)) {
    if (jax_enable_x64) {
      return ConvertToScalarBuffer<int64, py::int_>(scalar, client, device);
    } else {
      return ConvertToScalarBuffer<int, py::int_>(scalar, client, device);
    }
  } else if (py::isinstance<py::float_>(scalar)) {
    if (jax_enable_x64) {
      return ConvertToScalarBuffer<double, py::float_>(scalar, client, device);
    } else {
      return ConvertToScalarBuffer<float, py::float_>(scalar, client, device);
    }
  } else if (PyComplex_Check(scalar.ptr())) {
    Py_complex result = PyComplex_AsCComplex(scalar.ptr());
    if (result.real == -1.0 && PyErr_Occurred()) {
      PyErr_Clear();
      throw std::runtime_error("Could not convert the complex number");
    }
    if (jax_enable_x64) {
      xla::complex128 data(result.real, result.imag);
      xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          nullptr, device));
    } else {
      xla::complex64 data(result.real, result.imag);
      xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          nullptr, device));
    }
  }
  return InvalidArgument(
      "%s", absl::StrCat(
                "Not supported: The C++ jax jit execution path, only accepts "
                "DeviceArray, Numpy arrays, or Python scalars. Got type ",
                py::cast<std::string>(py::str(scalar.get_type()))));
}

}  // namespace

StatusOr<DevicePutResult> DevicePut(pybind11::handle obj, PjRtDevice* to_device,
                                    bool jax_enable_x64,
                                    xla::PyClient& pyclient) {
  static const auto* xla_module =
      new py::module(py::module::import("jax.interpreters.xla"));
  const auto& device_array = xla_module->attr("_DeviceArray");
namespace {
using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
    py::handle, PjRtDevice*, bool, xla::PyClient&)>;

  static const auto* numpy_module = new py::module(py::module::import("numpy"));
  const auto& np_array = numpy_module->attr("array");

DevicePutResult HandleBool(py::handle h, PjRtDevice* to_device,
                           bool jax_enable_x64, xla::PyClient& pyclient) {
  return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
                             h, pyclient.pjrt_client(), to_device),
                         /*weak_type=*/true);
}
  bool is_py_buffer = py::isinstance<PyBuffer>(obj);
  if (is_py_buffer) {
    // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
    PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj);
    bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
    if (buffer->device().contents == to_device) {
      return DevicePutResult(buffer->buffer(), weak_type);
    } else {
      // Performs a device-to-device copy if the devices are on the same
      // platform.
      // Buffers from different XLA backends are passed through the host.
      std::unique_ptr<PjRtBuffer> copied_buffer =
          ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
      return DevicePutResult(std::move(copied_buffer), weak_type);
    }

DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device,
                          bool jax_enable_x64, xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
                               obj, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/true);
  } else {
    return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
                               obj, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/true);
  }
}

  } else if (obj.get_type().is(device_array)) {
    if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
      return InvalidArgument(
          "Non-trivial lazy expression not supported in C++. "
          "Falling back to Python.");
    }
    PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
    bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
    // Same block as in the previous `if (is_py_buffer)`.
    if (buffer->device().contents == to_device) {
      return DevicePutResult(buffer->buffer(), weak_type);
    } else {
      std::unique_ptr<PjRtBuffer> copied_buffer =
          ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
      return DevicePutResult(std::move(copied_buffer), weak_type);
    }
  } else if (py::isinstance<py::array>(obj)) {
    py::array numpy_array = py::cast<py::array>(obj);
    if (IsFloat0(numpy_array)) {
      return InvalidArgument(
          "float0 numpy arrays not supported in C++. "
          "Falling back to Python.");
    }
    // If jax_enable_x64 is not set, we need to coerce 32 bits types.
    // Note that this is calling back to Python!
    if (!jax_enable_x64) {
      const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
      if (to_dtype) {
        numpy_array = np_array(numpy_array, *to_dtype);
      }
    }
template <bool weak_type>
StatusOr<DevicePutResult> HandleFloat(py::handle h, PjRtDevice* to_device,
                                      bool jax_enable_x64,
                                      xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/weak_type);
  } else {
    return DevicePutResult(ConvertToScalarBuffer<float, py::float_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/weak_type);
  }
}

template <bool weak_type>
StatusOr<DevicePutResult> HandleComplex(py::handle h, PjRtDevice* to_device,
                                        bool jax_enable_x64,
                                        xla::PyClient& pyclient) {
  // This branch is also taken for np.complex128:
  // isinstance(np.complex128(3), complex) returns True
  // isinstance(np.complex64(3), complex) returns False
  Py_complex result = PyComplex_AsCComplex(h.ptr());
  if (result.real == -1.0 && PyErr_Occurred()) {
    PyErr_Clear();
    throw std::runtime_error("Could not convert the complex number");
  }
  if (jax_enable_x64) {
    xla::complex128 data(result.real, result.imag);
    xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
    return DevicePutResult(
        ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
            &data, shape,
            xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
            nullptr, to_device)),
        /*weak_type=*/weak_type);
  } else {
    xla::complex64 data(result.real, result.imag);
    xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
    return DevicePutResult(
        ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
            &data, shape,
            xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
            nullptr, to_device)),
        /*weak_type=*/weak_type);
  }
}

StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
                                            PjRtDevice* to_device,
                                            bool jax_enable_x64,
                                            xla::PyClient& pyclient) {
  if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
    return InvalidArgument(
        "Non-trivial lazy expression not supported in C++. "
        "Falling back to Python.");
  }
  PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
  bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
  // Same block as in the previous `if (is_py_buffer)`.
  if (buffer->device().contents == to_device) {
    return DevicePutResult(buffer->buffer(), weak_type);
  } else {
    std::unique_ptr<PjRtBuffer> copied_buffer =
        ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
    return DevicePutResult(std::move(copied_buffer), weak_type);
  }
}

// Do not convert types, and only call PjRtBufferFromPyval, independently
// of the value of jax_enable_x64.
DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device,
                                      bool jax_enable_x64,
                                      xla::PyClient& pyclient) {
  std::unique_ptr<xla::PjRtBuffer> buffer =
      ValueOrThrow(pyclient.PjRtBufferFromPyval(
          h, to_device,
          /*force_copy=*/false, /*host_buffer_semantics=*/
          xla::PjRtClient::HostBufferSemantics::kZeroCopy));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device,
                             bool jax_enable_x64, xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    return DevicePutResult(ConvertToScalarBuffer<int64, py::int_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/false);
  } else {
    return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/false);
  }
}

DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device,
                             bool jax_enable_x64, xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    std::unique_ptr<xla::PjRtBuffer> buffer =
        ValueOrThrow(pyclient.PjRtBufferFromPyval(
            numpy_array, to_device,
            h, to_device,
            /*force_copy=*/false, /*host_buffer_semantics=*/
            xla::PjRtClient::HostBufferSemantics::kZeroCopy));
    return DevicePutResult(std::move(buffer), /*weak_type=*/false);
  } else {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<xla::PjRtBuffer> buffer,
        ScalarToBuffer(obj, jax_enable_x64, to_device->client(), to_device));
    return DevicePutResult(std::move(buffer), /*weak_type=*/true);
    static const auto* numpy = new py::module(py::module::import("numpy"));
    const auto& np_array = numpy->attr("array");

    // Note that this is calling back to Python!
    std::unique_ptr<xla::PjRtBuffer> buffer =
        ValueOrThrow(pyclient.PjRtBufferFromPyval(
            np_array(h, py::dtype("uint32")), to_device,
            /*force_copy=*/false, /*host_buffer_semantics=*/
            xla::PjRtClient::HostBufferSemantics::kZeroCopy));
    return DevicePutResult(std::move(buffer), /*weak_type=*/false);
  }
}

StatusOr<DevicePutResult> HandleNdarray(py::handle h, PjRtDevice* to_device,
                                        bool jax_enable_x64,
                                        xla::PyClient& pyclient) {
  py::array numpy_array = py::cast<py::array>(h);
  if (IsFloat0(numpy_array)) {
    return InvalidArgument("%s",
                           "float0 numpy arrays not supported in C++. "
                           "Falling back to Python.");
  }
  // If jax_enable_x64 is not set, we need to coerce 32 bits types.
  // Note that this is calling back to Python!
  if (!jax_enable_x64) {
    const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
    if (to_dtype) {
      static const auto* numpy = new py::module(py::module::import("numpy"));
      const auto& np_array = numpy->attr("array");
      numpy_array = np_array(numpy_array, *to_dtype);
    }
  }
  std::unique_ptr<xla::PjRtBuffer> buffer =
      ValueOrThrow(pyclient.PjRtBufferFromPyval(
          numpy_array, to_device,
          /*force_copy=*/false, /*host_buffer_semantics=*/
          xla::PjRtClient::HostBufferSemantics::kZeroCopy));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

}  // namespace

StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
                                    bool jax_enable_x64,
                                    xla::PyClient& pyclient) {
  static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
      [] {
        auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();

        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();

        const auto numpy = py::module::import("numpy");
        const auto xla_module = py::module::import("jax.interpreters.xla");
        const auto& device_array = xla_module.attr("_DeviceArray");

        // Python base types.
        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = HandleBool;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandleInt;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = HandleFloat<true>;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
            HandleComplex<true>;

        // DeviceArray and co.
        const auto pxla_module = py::module::import("jax.interpreters.pxla");
        const auto& sda = pxla_module.attr("ShardedDeviceArray");
        (*p)[device_array.ptr()] = HandleDeviceArray;
        (*p)[py::type::handle_of<DeviceArrayBase>().ptr()] = HandleDeviceArray;
        (*p)[sda.ptr()] = HandleBufferFromPyval;
        // Numpy arrays.
        (*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;

        // Numpy scalar types. For some of them, we share the handler with
        // Python types (np_int64, np_float64, np_complex128).
        (*p)[dtypes.np_bool.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int8.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int16.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int32.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int64.ptr()] = HandleNpBool;
        (*p)[dtypes.np_uint8.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_uint16.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_uint32.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_uint64.ptr()] = HandleUint64;
        (*p)[dtypes.np_float16.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_float32.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_float64.ptr()] = HandleFloat<false>;
        (*p)[dtypes.np_complex64.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_complex128.ptr()] = HandleComplex<false>;
        (*p)[dtypes.np_longlong.ptr()] = HandleNpBool;
        (*p)[dtypes.np_intc.ptr()] = HandleBufferFromPyval;

        return p;
      }();

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    for (auto base_class : arg.get_type().attr("mro")()) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, to_device, jax_enable_x64, pyclient);
      }
    }
    return InvalidArgument(
        "%s", absl::StrCat(
                  "Not supported: The C++ jax jit execution path, only accepts "
                  "DeviceArray, Numpy arrays scalars of supported types "
                  "(see implementation), or Python scalars. Got type ",
                  py::cast<std::string>(py::str(arg.get_type()))));
  } else {
    return res->second(arg, to_device, jax_enable_x64, pyclient);
  }
}

@@ -857,7 +1027,37 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
        std::move(static_argnums));
  });

  // Only for testing purposes
  // This function is not yet a full replacement for the Python one, because:
  // (a) it does not support abstract types,
  // (b) it does not set the device stickiness yet.
  // TODO(jblespiau): Finish the replacement of the Python feature.
  jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
                              ClientAndPtr<PjRtDevice> to_device) {
    std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
    StatusOr<DevicePutResult> results =
        DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
    if (!results.ok()) {
      throw std::runtime_error(results.status().error_message());
    }
    if (results->owned_buffer) {
      auto buffer = std::make_unique<PyBuffer>(
          pyclient, std::move(results->owned_buffer), Traceback::Get());

      static const auto* jax_core =
          new py::module(py::module::import("jax.core"));
      static const auto* shaped_array =
          new py::handle(jax_core->attr("ShapedArray"));
      buffer->SetAval((*shaped_array)(
          buffer->python_shape(), buffer->python_dtype(), results->weak_type));
      buffer->SetStickyDevice(py::none());

      return py::cast(std::move(buffer));
    } else {
      return py::cast<py::object>(obj);
    }
  });

  // All private members are only for testing purposes
  cfun.def("_cache_size", &CompiledFunction::cache_size);
  jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object {
    py::dtype dtype = py::dtype::from_args(obj);
@@ -870,17 +1070,6 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
  });
  jitlib.def("_is_float0", &IsFloat0);
  jitlib.def("_is_trivial", &IsTrivialLazyExpr);
  jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64,
                                   std::shared_ptr<xla::PyClient> client) {
    xla::PjRtClient* pjrt_client = client->pjrt_client();

    return std::make_unique<xla::PyBuffer>(
        client,
        ScalarToBuffer(scalar, jax_enable_x64, pjrt_client,
                       pjrt_client->local_devices()[0])
            .ValueOrDie(),
        nullptr);
  });
}

}  // namespace xla

@@ -159,7 +159,7 @@ struct DevicePutResult {
// If `obj` is not convertible to a `PjRtBuffer` from C++, an error will be
// returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not
// supported yet.
StatusOr<DevicePutResult> DevicePut(pybind11::handle obj, PjRtDevice* to_device,
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
                                    bool jax_enable_x64, PyClient& pyclient);

// The function to call in `xla.cc` to add the bindings for this module.
@@ -52,6 +52,15 @@ PyBuffer::~PyBuffer() {
  }
}

pybind11::tuple PyBuffer::python_shape() const {
  return IntSpanToTuple(buffer()->on_host_shape().dimensions());
}

pybind11::dtype PyBuffer::python_dtype() const {
  PrimitiveType primitive = buffer()->on_host_shape().element_type();
  return PrimitiveTypeToDtype(primitive).ValueOrDie();
}

ClientAndPtr<PjRtDevice> PyBuffer::device() const {
  return WrapWithClient(client_, buffer_->device());
}

@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/types/optional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback.h"
@@ -95,6 +96,9 @@ class PyBuffer : public DeviceArrayBase {
  // Returns the number of dimensions of the (host) numpy array.
  int ndim() const { return buffer()->on_host_shape().dimensions_size(); }

  pybind11::tuple python_shape() const;
  pybind11::dtype python_dtype() const;

  void SetStickyDevice(pybind11::object sticky_device);
  pybind11::object GetStickyDevice() const { return sticky_device_.value(); }