diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 2c364573e5b..44c263367af 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -239,7 +239,7 @@ struct CacheEntry { // a signature and if the object has been insterted already, other threads // will wait for the notification. absl::Notification compilation_complete; - absl::optional compilation_error = absl::nullopt; + absl::optional compilation_error = absl::nullopt; }; // A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the @@ -314,7 +314,7 @@ class CompiledFunction { // absl::optional is not supported bool first_compilation_started_ = false; absl::Notification first_compilation_complete_; - absl::optional first_compilation_error_ = absl::nullopt; + absl::optional first_compilation_error_ = absl::nullopt; }; CompiledFunction::CompiledFunction(py::function fun, @@ -646,7 +646,8 @@ CacheEntry& CompiledFunction::GetCacheEntry( py::gil_scoped_release gil_release; found_iterator->second->compilation_complete.WaitForNotification(); if (found_iterator->second->compilation_error) { - throw found_iterator->second->compilation_error.value(); + throw std::invalid_argument( + found_iterator->second->compilation_error.value().error_message()); } } return *(found_iterator->second); @@ -671,8 +672,8 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry( } else { try { executable_and_pytree = cache_miss_fun_(*args, **kwargs); - } catch (const std::exception& e) { - cache_entry.compilation_error = e; + } catch (const py::error_already_set& e) { + cache_entry.compilation_error = InvalidArgument("%s", e.what()); cache_entry.compilation_complete.Notify(); throw; } @@ -736,16 +737,17 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { if (!first_compilation_complete_.HasBeenNotified()) { py::gil_scoped_release gil_release; first_compilation_complete_.WaitForNotification(); - if (first_compilation_error_) { - throw first_compilation_error_.value(); - } + } + if (first_compilation_error_) { + throw std::invalid_argument( + first_compilation_error_.value().error_message()); } } else { first_compilation_started_ = true; try { cache_miss_result = cache_miss_fun_(*args, **kwargs); - } catch (const std::exception& e) { - first_compilation_error_ = e; + } catch (const py::error_already_set& e) { + first_compilation_error_ = InvalidArgument("%s", e.what()); first_compilation_complete_.Notify(); throw; } @@ -754,9 +756,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { pyclient_ = executable->client(); default_device_ = executable->LocalDevices()[0].contents; + if (!default_device_) { + throw std::invalid_argument( + "executable->LocalDevices()[0] should not be null!"); + } first_compilation_complete_.Notify(); } } + CHECK(default_device_); // The C++ jit do not support Tracers arguments yet. The Python-based jit // function will be called if any of the dynamic arguments is unsupported.