From 76e17494d6fc61af892703bc52a4ace5586dc562 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Wed, 16 Dec 2020 07:32:23 -0800 Subject: [PATCH] Have the C++ path raise a cleaner error. PiperOrigin-RevId: 347818682 Change-Id: I6a02aea7843761f229325fcb1b0d2078af242103 --- tensorflow/compiler/xla/python/jax_jit.cc | 15 +++++++++++---- tensorflow/compiler/xla/python/jax_jit.h | 8 ++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index c5a4fa53ab7..0c6d207841e 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -178,9 +178,13 @@ H AbslHashValue(H h, const CallSignature& s) { // Filter out static arguments, flatten and concatenate other arguments (i.e. // dynamic positional and keyword arguments), filling `arguments` in place. -void ParseArguments(const py::args& args, const py::kwargs& py_kwargs, - absl::Span static_argnums, - ParsedArgumentsAsBuffers& arguments) { +Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs, + absl::Span static_argnums, + ParsedArgumentsAsBuffers& arguments) { + if (static_argnums.size() > args.size()) { + return InvalidArgument( + "%s", "[jaxjit] Error with static argnums, executing the Python path."); + } arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() - static_argnums.size()); arguments.signature.dynamic_positional_args_treedef.reserve( @@ -230,6 +234,7 @@ void ParseArguments(const py::args& args, const py::kwargs& py_kwargs, arguments.signature.keyword_args[i].value_treedef.FlattenInto( kwargs[i].second, arguments.flat_dynamic_args); } + return Status::OK(); } namespace { @@ -764,7 +769,9 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { return fun_(*args, **kwargs); } ParsedArgumentsAsBuffers arguments; - ParseArguments(args, kwargs, static_argnums_, arguments); + if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) { + return py::cast(cache_miss_(*args, **kwargs))[0]; + } // The C++ jit do not support Tracers arguments inputs yet. The Python-based // jit function will be called if any of the dynamic arguments is unsupported. diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h index 11855a668c2..08ab7c8c018 100644 --- a/tensorflow/compiler/xla/python/jax_jit.h +++ b/tensorflow/compiler/xla/python/jax_jit.h @@ -132,10 +132,10 @@ struct ParsedArgumentsAsBuffers { // Filter out static arguments, flatten and concatenate other arguments (i.e. // dynamic positional and keyword arguments), filling `arguments` in place. -void ParseArguments(const pybind11::args& args, - const pybind11::kwargs& py_kwargs, - absl::Span static_argnums, - ParsedArgumentsAsBuffers& arguments); +Status ParseArguments(const pybind11::args& args, + const pybind11::kwargs& py_kwargs, + absl::Span static_argnums, + ParsedArgumentsAsBuffers& arguments); struct DevicePutResult { explicit DevicePutResult(PjRtBuffer* b, bool weak_type)