Have the C++ path raise a cleaner error.
PiperOrigin-RevId: 347818682 Change-Id: I6a02aea7843761f229325fcb1b0d2078af242103
This commit is contained in:
parent
26f4c529e7
commit
76e17494d6
@ -178,9 +178,13 @@ H AbslHashValue(H h, const CallSignature& s) {
|
||||
|
||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||
void ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
if (static_argnums.size() > args.size()) {
|
||||
return InvalidArgument(
|
||||
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
|
||||
}
|
||||
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
|
||||
static_argnums.size());
|
||||
arguments.signature.dynamic_positional_args_treedef.reserve(
|
||||
@ -230,6 +234,7 @@ void ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
|
||||
kwargs[i].second, arguments.flat_dynamic_args);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -764,7 +769,9 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
return fun_(*args, **kwargs);
|
||||
}
|
||||
ParsedArgumentsAsBuffers arguments;
|
||||
ParseArguments(args, kwargs, static_argnums_, arguments);
|
||||
if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
|
||||
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
|
||||
}
|
||||
|
||||
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
|
||||
// jit function will be called if any of the dynamic arguments is unsupported.
|
||||
|
@ -132,10 +132,10 @@ struct ParsedArgumentsAsBuffers {
|
||||
|
||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||
void ParseArguments(const pybind11::args& args,
|
||||
const pybind11::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments);
|
||||
Status ParseArguments(const pybind11::args& args,
|
||||
const pybind11::kwargs& py_kwargs,
|
||||
absl::Span<int const> static_argnums,
|
||||
ParsedArgumentsAsBuffers& arguments);
|
||||
|
||||
struct DevicePutResult {
|
||||
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
|
||||
|
Loading…
x
Reference in New Issue
Block a user