diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index a8d2030099c..f3c5e9ccd0a 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -1097,6 +1097,9 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args, } py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { + if (JitIsDisabled()) { + return fun_(*args, **kwargs); + } if (always_fallback_to_python_) { return py::cast(cache_miss_(*args, **kwargs))[0]; } @@ -1131,9 +1134,7 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { } } CHECK(default_device_); - if (JitIsDisabled()) { - return fun_(*args, **kwargs); - } + ParsedArgumentsAsBuffers arguments; if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) { return py::cast(cache_miss_(*args, **kwargs))[0];