From 10c06df9492c44062f3a6b0c04bec325c7738426 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Sat, 16 Jan 2021 15:06:30 -0800 Subject: [PATCH] Move jax_jit functions to the `jax` namespace. PiperOrigin-RevId: 352208505 Change-Id: Iff1ac8e432f13d41f71b41b6e560c35765549043 --- tensorflow/compiler/xla/python/jax_jit.cc | 204 ++++++++++++---------- tensorflow/compiler/xla/python/jax_jit.h | 42 ++--- tensorflow/compiler/xla/python/xla.cc | 2 +- 3 files changed, 131 insertions(+), 117 deletions(-) diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 3ba1a5b32a9..2de272b7fbf 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -53,7 +53,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/status.h" -namespace xla { +namespace jax { namespace py = pybind11; @@ -180,11 +180,11 @@ H AbslHashValue(H h, const CallSignature& s) { // Filter out static arguments, flatten and concatenate other arguments (i.e. // dynamic positional and keyword arguments), filling `arguments` in place. -Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs, - absl::Span static_argnums, - ParsedArgumentsAsBuffers& arguments) { +xla::Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs, + absl::Span static_argnums, + ParsedArgumentsAsBuffers& arguments) { if (static_argnums.size() > args.size()) { - return InvalidArgument( + return xla::InvalidArgument( "%s", "[jaxjit] Error with static argnums, executing the Python path."); } arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() - @@ -196,7 +196,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs, for (size_t i = 0; i < args.size(); ++i) { if (std::find(static_argnums.begin(), static_argnums.end(), i) == static_argnums.end()) { - PyTreeDef pytree_def; + xla::PyTreeDef pytree_def; pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args); arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def); } else { @@ -210,7 +210,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs, std::vector> kwargs(py_kwargs.begin(), py_kwargs.end()); // We first intern the keys, then sort them (by name, as in the Python path) - // (see also PyTreeDef::Flatten) and then create the signatures. + // (see also xla::PyTreeDef::Flatten) and then create the signatures. // TODO(jblespiau): We should be able to sort the keys by interned-key // pointers, but this requires the Python compilation to do the same. arguments.signature.keyword_args.resize(kwargs.size()); @@ -236,7 +236,7 @@ Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs, arguments.signature.keyword_args[i].value_treedef.FlattenInto( kwargs[i].second, arguments.flat_dynamic_args); } - return Status::OK(); + return xla::Status::OK(); } namespace { @@ -374,11 +374,11 @@ std::unique_ptr ConvertToScalarBuffer( namespace { using ToArgSignatureHandler = - std::function(py::handle, bool)>; + std::function(py::handle, bool)>; } -StatusOr ArgSignatureOfValue(pybind11::handle arg, - bool jax_enable_x64) { +xla::StatusOr ArgSignatureOfValue(pybind11::handle arg, + bool jax_enable_x64) { static const absl::flat_hash_map* const handlers = [] { auto p = new absl::flat_hash_map(); @@ -389,38 +389,40 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); // The 4 Python native types. - ToArgSignatureHandler bool_handler = [](py::handle, - bool) -> StatusOr { - return ArgSignature(PrimitiveType::PRED, {}, true); + ToArgSignatureHandler bool_handler = + [](py::handle, bool) -> xla::StatusOr { + return ArgSignature(xla::PrimitiveType::PRED, {}, true); }; ToArgSignatureHandler int_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { if (jax_enable_x64) { - return ArgSignature(PrimitiveType::S64, {}, true); + return ArgSignature(xla::PrimitiveType::S64, {}, true); } else { - return ArgSignature(PrimitiveType::S32, {}, true); + return ArgSignature(xla::PrimitiveType::S32, {}, true); } }; ToArgSignatureHandler float_handler = - [&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr { + [&dtypes](py::handle h, + bool jax_enable_x64) -> xla::StatusOr { // Only Python native types has a True weak_type. bool weak_type = !py::isinstance(h, dtypes.np_float64); if (jax_enable_x64) { - return ArgSignature(PrimitiveType::F64, {}, weak_type); + return ArgSignature(xla::PrimitiveType::F64, {}, weak_type); } else { - return ArgSignature(PrimitiveType::F32, {}, weak_type); + return ArgSignature(xla::PrimitiveType::F32, {}, weak_type); } }; ToArgSignatureHandler complex_handler = - [&dtypes](py::handle h, bool jax_enable_x64) -> StatusOr { + [&dtypes](py::handle h, + bool jax_enable_x64) -> xla::StatusOr { // Note that this branch is also taken for np.complex128: // isinstance(np.complex128(3), complex) returns True // isinstance(np.complex64(3), complex) returns False bool weak_type = !py::isinstance(h, dtypes.np_complex128); if (jax_enable_x64) { - return ArgSignature(PrimitiveType::C128, {}, weak_type); + return ArgSignature(xla::PrimitiveType::C128, {}, weak_type); } else { - return ArgSignature(PrimitiveType::C64, {}, weak_type); + return ArgSignature(xla::PrimitiveType::C64, {}, weak_type); } }; @@ -432,18 +434,19 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, // The Buffer types // PyBuffer necessarily has a trivial LazyExpr, no need to check it. ToArgSignatureHandler buffer_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { - PyBuffer* buffer = py::cast(h); + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { + xla::PyBuffer* buffer = py::cast(h); bool weak_type = py::cast(h.attr("aval").attr("weak_type")); return ArgSignature(buffer->buffer()->on_host_shape().element_type(), buffer->buffer()->on_host_shape().dimensions(), weak_type); }; - (*p)[py::type::handle_of().ptr()] = buffer_handler; + (*p)[py::type::handle_of().ptr()] = buffer_handler; ToArgSignatureHandler device_array_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { py::handle aval = h.attr("aval"); - TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(aval.attr("dtype"))); + TF_ASSIGN_OR_RETURN(auto dtype, + xla::DtypeToPrimitiveType(aval.attr("dtype"))); return ArgSignature(dtype, py::cast>(aval.attr("shape")), py::cast(aval.attr("weak_type"))); @@ -452,10 +455,10 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, (*p)[device_array.ptr()] = device_array_handler; ToArgSignatureHandler numpy_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { py::array numpy_array = py::cast(h); if (IsFloat0(numpy_array)) { - return InvalidArgument( + return xla::InvalidArgument( "float0 numpy arrays not supported in C++. " "Falling back to Python."); } @@ -463,11 +466,11 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, const py::dtype raw_dtype = numpy_array.dtype(); const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype); - PrimitiveType dtype; + xla::PrimitiveType dtype; if (to_dtype) { - TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(*to_dtype)); + TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(*to_dtype)); } else { - TF_ASSIGN_OR_RETURN(dtype, DtypeToPrimitiveType(raw_dtype)); + TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(raw_dtype)); } // We need the reinterpret_cast for the OSS version to build. return ArgSignature(dtype, @@ -477,7 +480,7 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, /*weak_type=*/false); } TF_ASSIGN_OR_RETURN(auto dtype, - DtypeToPrimitiveType(numpy_array.dtype())); + xla::DtypeToPrimitiveType(numpy_array.dtype())); return ArgSignature(dtype, absl::MakeConstSpan(reinterpret_cast( numpy_array.shape()), @@ -489,27 +492,28 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, (*p)[ndarray.ptr()] = numpy_handler; ToArgSignatureHandler np_uint64_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { if (jax_enable_x64) { - return ArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false); + return ArgSignature(xla::PrimitiveType::U64, {}, /*weak_type=*/false); } else { - return ArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false); + return ArgSignature(xla::PrimitiveType::U32, {}, /*weak_type=*/false); } }; ToArgSignatureHandler np_int_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { if (jax_enable_x64) { - return ArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false); + return ArgSignature(xla::PrimitiveType::S64, {}, /*weak_type=*/false); } else { - return ArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false); + return ArgSignature(xla::PrimitiveType::S32, {}, /*weak_type=*/false); } }; ToArgSignatureHandler numpy_array_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](py::handle h, bool jax_enable_x64) -> xla::StatusOr { // This block deals with all numpy scalar types, except for int64_dt, // float64_dt and complex128_dt which are taken care of in previous if // blocks. - TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(h.attr("dtype"))); + TF_ASSIGN_OR_RETURN(auto dtype, + xla::DtypeToPrimitiveType(h.attr("dtype"))); return ArgSignature(dtype, {}, /*weak_type=*/false); }; @@ -545,7 +549,7 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, return res->second(arg, jax_enable_x64); } } - return InvalidArgument( + return xla::InvalidArgument( "%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts " "Buffer/DeviceArray/ShardedDeviceArray, Numpy " "arrays scalars of supported types " @@ -557,17 +561,17 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, } namespace { -using DevicePutFunc = std::function( - py::handle, PjRtDevice*, bool, xla::PyClient&)>; +using DevicePutFunc = std::function( + py::handle, xla::PjRtDevice*, bool, xla::PyClient&)>; -DevicePutResult HandleBool(py::handle h, PjRtDevice* to_device, +DevicePutResult HandleBool(py::handle h, xla::PjRtDevice* to_device, bool jax_enable_x64, xla::PyClient& pyclient) { return DevicePutResult(ConvertToScalarBuffer( h, pyclient.pjrt_client(), to_device), /*weak_type=*/true); } -DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device, +DevicePutResult HandleInt(py::handle obj, xla::PjRtDevice* to_device, bool jax_enable_x64, xla::PyClient& pyclient) { if (jax_enable_x64) { return DevicePutResult(ConvertToScalarBuffer( @@ -581,9 +585,10 @@ DevicePutResult HandleInt(py::handle obj, PjRtDevice* to_device, } template -StatusOr HandleFloat(py::handle h, PjRtDevice* to_device, - bool jax_enable_x64, - xla::PyClient& pyclient) { +xla::StatusOr HandleFloat(py::handle h, + xla::PjRtDevice* to_device, + bool jax_enable_x64, + xla::PyClient& pyclient) { if (jax_enable_x64) { return DevicePutResult(ConvertToScalarBuffer( h, pyclient.pjrt_client(), to_device), @@ -596,9 +601,10 @@ StatusOr HandleFloat(py::handle h, PjRtDevice* to_device, } template -StatusOr HandleComplex(py::handle h, PjRtDevice* to_device, - bool jax_enable_x64, - xla::PyClient& pyclient) { +xla::StatusOr HandleComplex(py::handle h, + xla::PjRtDevice* to_device, + bool jax_enable_x64, + xla::PyClient& pyclient) { // This branch is also taken for np.complex128: // isinstance(np.complex128(3), complex) returns True // isinstance(np.complex64(3), complex) returns False @@ -628,22 +634,22 @@ StatusOr HandleComplex(py::handle h, PjRtDevice* to_device, } } -StatusOr HandleDeviceArray(py::handle obj, - PjRtDevice* to_device, - bool jax_enable_x64, - xla::PyClient& pyclient) { +xla::StatusOr HandleDeviceArray(py::handle obj, + xla::PjRtDevice* to_device, + bool jax_enable_x64, + xla::PyClient& pyclient) { if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) { - return InvalidArgument( + return xla::InvalidArgument( "Non-trivial lazy expression not supported in C++. " "Falling back to Python."); } - PyBuffer* buffer = py::cast(obj.attr("device_buffer")); + xla::PyBuffer* buffer = py::cast(obj.attr("device_buffer")); bool weak_type = py::cast(obj.attr("aval").attr("weak_type")); // Same block as in the previous `if (is_py_buffer)`. if (buffer->device().contents == to_device) { return DevicePutResult(buffer->buffer(), weak_type); } else { - std::unique_ptr copied_buffer = + std::unique_ptr copied_buffer = ValueOrThrow(buffer->buffer()->CopyToDevice(to_device)); return DevicePutResult(std::move(copied_buffer), weak_type); } @@ -651,7 +657,7 @@ StatusOr HandleDeviceArray(py::handle obj, // Do not convert types, and only call PjRtBufferFromPyval, independently // of the value of jax_enable_x64. -DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device, +DevicePutResult HandleBufferFromPyval(py::handle h, xla::PjRtDevice* to_device, bool jax_enable_x64, xla::PyClient& pyclient) { std::unique_ptr buffer = @@ -662,7 +668,7 @@ DevicePutResult HandleBufferFromPyval(py::handle h, PjRtDevice* to_device, return DevicePutResult(std::move(buffer), /*weak_type=*/false); } -DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device, +DevicePutResult HandleNpBool(py::handle h, xla::PjRtDevice* to_device, bool jax_enable_x64, xla::PyClient& pyclient) { if (jax_enable_x64) { return DevicePutResult(ConvertToScalarBuffer( @@ -675,7 +681,7 @@ DevicePutResult HandleNpBool(py::handle h, PjRtDevice* to_device, } } -DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device, +DevicePutResult HandleUint64(py::handle h, xla::PjRtDevice* to_device, bool jax_enable_x64, xla::PyClient& pyclient) { if (jax_enable_x64) { std::unique_ptr buffer = @@ -698,14 +704,15 @@ DevicePutResult HandleUint64(py::handle h, PjRtDevice* to_device, } } -StatusOr HandleNdarray(py::handle h, PjRtDevice* to_device, - bool jax_enable_x64, - xla::PyClient& pyclient) { +xla::StatusOr HandleNdarray(py::handle h, + xla::PjRtDevice* to_device, + bool jax_enable_x64, + xla::PyClient& pyclient) { py::array numpy_array = py::cast(h); if (IsFloat0(numpy_array)) { - return InvalidArgument("%s", - "float0 numpy arrays not supported in C++. " - "Falling back to Python."); + return xla::InvalidArgument("%s", + "float0 numpy arrays not supported in C++. " + "Falling back to Python."); } // If jax_enable_x64 is not set, we need to coerce 32 bits types. // Note that this is calling back to Python! @@ -727,9 +734,10 @@ StatusOr HandleNdarray(py::handle h, PjRtDevice* to_device, } // namespace -StatusOr DevicePut(pybind11::handle arg, PjRtDevice* to_device, - bool jax_enable_x64, - xla::PyClient& pyclient) { +xla::StatusOr DevicePut(pybind11::handle arg, + xla::PjRtDevice* to_device, + bool jax_enable_x64, + xla::PyClient& pyclient) { static const absl::flat_hash_map* const handlers = [] { auto p = new absl::flat_hash_map(); @@ -751,7 +759,8 @@ StatusOr DevicePut(pybind11::handle arg, PjRtDevice* to_device, const auto pxla_module = py::module::import("jax.interpreters.pxla"); const auto& sda = pxla_module.attr("ShardedDeviceArray"); (*p)[device_array.ptr()] = HandleDeviceArray; - (*p)[py::type::handle_of().ptr()] = HandleDeviceArray; + (*p)[py::type::handle_of().ptr()] = + HandleDeviceArray; (*p)[sda.ptr()] = HandleBufferFromPyval; // Numpy arrays. (*p)[numpy.attr("ndarray").ptr()] = HandleNdarray; @@ -786,7 +795,7 @@ StatusOr DevicePut(pybind11::handle arg, PjRtDevice* to_device, return res->second(arg, to_device, jax_enable_x64, pyclient); } } - return InvalidArgument( + return xla::InvalidArgument( "%s", absl::StrCat( "Not supported: The C++ jax jit execution path, only accepts " "DeviceArray, Numpy arrays scalars of supported types " @@ -801,7 +810,7 @@ namespace { struct CacheEntry { std::shared_ptr executable; - PyTreeDef out_pytree_def; + xla::PyTreeDef out_pytree_def; // We use Python types within the vector because this is what we will be // returning to Python. No need to convert back and forth. // We need py::object to maintain the objects alive. @@ -817,7 +826,7 @@ struct CacheEntry { // a signature and if the object has been insterted already, other threads // will wait for the notification. absl::Notification compilation_complete; - absl::optional compilation_error = absl::nullopt; + absl::optional compilation_error = absl::nullopt; // Trivial computation will fallback to Python. // Running a jax(pmap) will also fallback to Python. bool fall_back_to_python = false; @@ -896,7 +905,7 @@ class CompiledFunction { // the `default_device_` which will be used as the targeted device. In // which case, we will always copy input buffers to this device. std::shared_ptr default_pyclient_ = nullptr; - xla::ClientAndPtr default_pydevice_; + xla::ClientAndPtr default_pydevice_; xla::PjRtDevice* default_device_ = nullptr; bool is_committed_; }; @@ -924,11 +933,12 @@ CompiledFunction::~CompiledFunction() { // Converts flattened arguments contained in ParsedArgumentsAsBuffers in // place. If arguments are `DeviceArray`, they must all be on the same `Device`. // -// Returns `OkStatus()` on success. Returning an error should lead to calling -// the Python fallback. -Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, - xla::PjRtDevice* default_device, bool is_committed, - ParsedArgumentsAsBuffers& arguments) { +// Returns `Okxla::Status()` on success. Returning an error should lead to +// calling the Python fallback. +xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, + xla::PjRtDevice* default_device, + bool is_committed, + ParsedArgumentsAsBuffers& arguments) { std::vector& arg_buffers = arguments.arg_buffers; auto& keep_alive = arguments.keep_alive; @@ -953,7 +963,8 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, for (py::handle arg : arguments.flat_dynamic_args) { // We specically only deal with DeviceArray (not ShardedDeviceArray). // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored"). - if (py::isinstance(arg) || arg.get_type().is(device_array)) { + if (py::isinstance(arg) || + arg.get_type().is(device_array)) { xla::PyBuffer* buffer; if (arg.attr("_device").is_none()) { // Skip non-sticky devices. continue; @@ -962,7 +973,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, // This can fail, e.g. when device_buffer is a `DeviceConstant`. buffer = py::cast(arg.attr("device_buffer")); } catch (const py::cast_error& e) { - return InvalidArgument( + return xla::InvalidArgument( "%s", absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " "`device_buffer` field is of type ", @@ -995,7 +1006,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, TF_ASSIGN_OR_RETURN(DevicePutResult on_device, DevicePut(arg, data_device, jax_enable_x64, pyclient)); - PjRtBuffer* buffer = on_device.buffer; + xla::PjRtBuffer* buffer = on_device.buffer; arg_buffers.push_back(buffer); if (on_device.owned_buffer) { keep_alive.emplace_back(std::move(on_device.owned_buffer)); @@ -1005,7 +1016,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, buffer->on_host_shape().dimensions(), on_device.weak_type); arguments.signature.dynamic_args_signatures.push_back(std::move(sig)); } - return Status::OK(); + return xla::Status::OK(); } CacheEntry* CompiledFunction::GetCacheEntryIfPresent( @@ -1063,7 +1074,7 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args, // The presence of jit(pmap) is detected from Python. CHECK_EQ(num_devices, 1); - auto out_tree = py::cast(executable_handlers_out_tree[1]); + auto out_tree = py::cast(executable_handlers_out_tree[1]); cache_entry->out_pytree_def = std::move(out_tree); cache_entry->sticky_device = @@ -1105,7 +1116,7 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { if (!default_device_) { py::object device_and_is_committed = get_device_(); try { - default_pydevice_ = py::cast>( + default_pydevice_ = py::cast>( device_and_is_committed.attr("default_device")); } catch (const py::cast_error& e) { // Pathways and Cloud TPU 2VM runtime. @@ -1217,16 +1228,16 @@ void BuildJaxjitSubmodule(pybind11::module& m) { // (b) it does not set the device stickiness yet. // TODO(jblespiau): Finish the replacement of the Python feature. jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64, - ClientAndPtr to_device) { + xla::ClientAndPtr to_device) { std::shared_ptr& pyclient = to_device.client; - StatusOr results = + xla::StatusOr results = DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient); if (!results.ok()) { throw std::runtime_error(results.status().error_message()); } if (results->owned_buffer) { - auto buffer = std::make_unique( - pyclient, std::move(results->owned_buffer), Traceback::Get()); + auto buffer = std::make_unique( + pyclient, std::move(results->owned_buffer), xla::Traceback::Get()); static const auto* jax_core = new py::module(py::module::import("jax.core")); @@ -1248,9 +1259,10 @@ void BuildJaxjitSubmodule(pybind11::module& m) { [](const ArgSignature& sig) { return PrimitiveTypeToDtype(sig.dtype); }) - .def_property_readonly( - "shape", - [](const ArgSignature& sig) { return IntSpanToTuple(sig.shape); }) + .def_property_readonly("shape", + [](const ArgSignature& sig) { + return xla::IntSpanToTuple(sig.shape); + }) .def_readonly("weak_type", &ArgSignature::weak_type); jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue); @@ -1269,4 +1281,4 @@ void BuildJaxjitSubmodule(pybind11::module& m) { jitlib.def("_is_trivial", &IsTrivialLazyExpr); } -} // namespace xla +} // namespace jax diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h index c5a6fe152f7..53a2d8aa778 100644 --- a/tensorflow/compiler/xla/python/jax_jit.h +++ b/tensorflow/compiler/xla/python/jax_jit.h @@ -24,15 +24,15 @@ limitations under the License. #include "tensorflow/compiler/xla/python/pytree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace xla { +namespace jax { // Describes the abstract shape and dtype of an argument. struct ArgSignature { - ArgSignature(PrimitiveType dtype, absl::Span shape, + ArgSignature(xla::PrimitiveType dtype, absl::Span shape, bool weak_type) : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {} // This is the XLA dtype of the object. - const PrimitiveType dtype; + const xla::PrimitiveType dtype; const absl::InlinedVector shape; // JAX arguments can be of weak type, if and only if they are Python scalars // or `DeviceArray` values such that `aval.weak_type` is true. @@ -66,7 +66,7 @@ struct CallSignature { // To avoid comparing strings, we intern the kwargs strings. // The compilation cache holds a reference to all the keys. pybind11::handle key; - PyTreeDef value_treedef; + xla::PyTreeDef value_treedef; bool operator==(const KwargEntry& other) const { return key.ptr() == other.key.ptr() && value_treedef == other.value_treedef; @@ -78,13 +78,13 @@ struct CallSignature { // order of their argnum index. std::vector static_args; // A PyTreeDef for each positional dynamic (i.e. not static) argument. - std::vector dynamic_positional_args_treedef; + std::vector dynamic_positional_args_treedef; // Keyword arguments. Sorted by the keyword name. std::vector keyword_args; // Shape and dtype for both the dynamic positional arguments and the keyword // arguments (sorted by keyword name). std::vector dynamic_args_signatures; - PjRtDevice* device; + xla::PjRtDevice* device; bool operator==(const CallSignature& other) const; bool operator!=(const CallSignature& other) const { @@ -132,28 +132,28 @@ struct ParsedArgumentsAsBuffers { // Filter out static arguments, flatten and concatenate other arguments (i.e. // dynamic positional and keyword arguments), filling `arguments` in place. -Status ParseArguments(const pybind11::args& args, - const pybind11::kwargs& py_kwargs, - absl::Span static_argnums, - ParsedArgumentsAsBuffers& arguments); +xla::Status ParseArguments(const pybind11::args& args, + const pybind11::kwargs& py_kwargs, + absl::Span static_argnums, + ParsedArgumentsAsBuffers& arguments); struct DevicePutResult { - explicit DevicePutResult(PjRtBuffer* b, bool weak_type) + explicit DevicePutResult(xla::PjRtBuffer* b, bool weak_type) : buffer(b), weak_type(weak_type), owned_buffer(nullptr) {} - DevicePutResult(std::unique_ptr new_buffer, bool weak_type) + DevicePutResult(std::unique_ptr new_buffer, bool weak_type) : buffer(new_buffer.get()), weak_type(weak_type), owned_buffer(std::move(new_buffer)) {} - PjRtBuffer* buffer; + xla::PjRtBuffer* buffer; bool weak_type; - std::unique_ptr owned_buffer; + std::unique_ptr owned_buffer; }; // Returns the ArgSignature associated with an argument. Returns an error if // the argument is not supported. -StatusOr ArgSignatureOfValue(pybind11::handle arg, - bool jax_enable_x64); +xla::StatusOr ArgSignatureOfValue(pybind11::handle arg, + bool jax_enable_x64); // Moves a device-like object to be on device. // - If the object is already on device, `owned_buffer` will be nullptr. @@ -161,15 +161,17 @@ StatusOr ArgSignatureOfValue(pybind11::handle arg, // `owned_buffer`. // In all cases, `buffer` will point to the already existing or newly created // buffer. -// If `obj` is not convertible to a `PjRtBuffer` from C++, an error will be +// If `obj` is not convertible to a `xla::PjRtBuffer` from C++, an error will be // returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are not // supported yet. -StatusOr DevicePut(pybind11::handle arg, PjRtDevice* to_device, - bool jax_enable_x64, PyClient& pyclient); +xla::StatusOr DevicePut(pybind11::handle arg, + xla::PjRtDevice* to_device, + bool jax_enable_x64, + xla::PyClient& pyclient); // The function to call in `xla.cc` to add the bindings for this module. void BuildJaxjitSubmodule(pybind11::module& m); -} // namespace xla +} // namespace jax #endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 3049bdae335..320d078dc2c 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -415,7 +415,7 @@ PYBIND11_MODULE(xla_extension, m) { BuildOpsSubmodule(&m); BuildOutfeedReceiverSubmodule(&m); BuildPytreeSubmodule(m); - BuildJaxjitSubmodule(m); + jax::BuildJaxjitSubmodule(m); jax::BuildPmapSubmodule(m); BuildTracebackSubmodule(m);