Have the C++ path raise a cleaner error.

PiperOrigin-RevId: 347818682
Change-Id: I6a02aea7843761f229325fcb1b0d2078af242103
This commit is contained in:
Jean-Baptiste Lespiau 2020-12-16 07:32:23 -08:00 committed by TensorFlower Gardener
parent 26f4c529e7
commit 76e17494d6
2 changed files with 15 additions and 8 deletions

View File

@ -178,9 +178,13 @@ H AbslHashValue(H h, const CallSignature& s) {
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
void ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments) {
Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments) {
if (static_argnums.size() > args.size()) {
return InvalidArgument(
"%s", "[jaxjit] Error with static argnums, executing the Python path.");
}
arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
static_argnums.size());
arguments.signature.dynamic_positional_args_treedef.reserve(
@ -230,6 +234,7 @@ void ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
arguments.signature.keyword_args[i].value_treedef.FlattenInto(
kwargs[i].second, arguments.flat_dynamic_args);
}
return Status::OK();
}
namespace {
@ -764,7 +769,9 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
return fun_(*args, **kwargs);
}
ParsedArgumentsAsBuffers arguments;
ParseArguments(args, kwargs, static_argnums_, arguments);
if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
// jit function will be called if any of the dynamic arguments is unsupported.

View File

@ -132,10 +132,10 @@ struct ParsedArgumentsAsBuffers {
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
void ParseArguments(const pybind11::args& args,
const pybind11::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments);
Status ParseArguments(const pybind11::args& args,
const pybind11::kwargs& py_kwargs,
absl::Span<int const> static_argnums,
ParsedArgumentsAsBuffers& arguments);
struct DevicePutResult {
explicit DevicePutResult(PjRtBuffer* b, bool weak_type)