Have the C++ path raise a cleaner error.
PiperOrigin-RevId: 347818682 Change-Id: I6a02aea7843761f229325fcb1b0d2078af242103
This commit is contained in:
parent
26f4c529e7
commit
76e17494d6
@ -178,9 +178,13 @@ H AbslHashValue(H h, const CallSignature& s) {
|
|||||||
|
|
||||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||||
void ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||||
absl::Span<int const> static_argnums,
|
absl::Span<int const> static_argnums,
|
||||||
ParsedArgumentsAsBuffers& arguments) {
|
ParsedArgumentsAsBuffers& arguments) {
|
||||||
|
if (static_argnums.size() > args.size()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
|
||||||
|
}
|
||||||
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
|
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
|
||||||
static_argnums.size());
|
static_argnums.size());
|
||||||
arguments.signature.dynamic_positional_args_treedef.reserve(
|
arguments.signature.dynamic_positional_args_treedef.reserve(
|
||||||
@ -230,6 +234,7 @@ void ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
|
|||||||
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
|
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
|
||||||
kwargs[i].second, arguments.flat_dynamic_args);
|
kwargs[i].second, arguments.flat_dynamic_args);
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -764,7 +769,9 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
|||||||
return fun_(*args, **kwargs);
|
return fun_(*args, **kwargs);
|
||||||
}
|
}
|
||||||
ParsedArgumentsAsBuffers arguments;
|
ParsedArgumentsAsBuffers arguments;
|
||||||
ParseArguments(args, kwargs, static_argnums_, arguments);
|
if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
|
||||||
|
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
|
||||||
|
}
|
||||||
|
|
||||||
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
|
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
|
||||||
// jit function will be called if any of the dynamic arguments is unsupported.
|
// jit function will be called if any of the dynamic arguments is unsupported.
|
||||||
|
@ -132,10 +132,10 @@ struct ParsedArgumentsAsBuffers {
|
|||||||
|
|
||||||
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
// Filter out static arguments, flatten and concatenate other arguments (i.e.
|
||||||
// dynamic positional and keyword arguments), filling `arguments` in place.
|
// dynamic positional and keyword arguments), filling `arguments` in place.
|
||||||
void ParseArguments(const pybind11::args& args,
|
Status ParseArguments(const pybind11::args& args,
|
||||||
const pybind11::kwargs& py_kwargs,
|
const pybind11::kwargs& py_kwargs,
|
||||||
absl::Span<int const> static_argnums,
|
absl::Span<int const> static_argnums,
|
||||||
ParsedArgumentsAsBuffers& arguments);
|
ParsedArgumentsAsBuffers& arguments);
|
||||||
|
|
||||||
struct DevicePutResult {
|
struct DevicePutResult {
|
||||||
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
|
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user