Correctly raise an error if the Python compilation failed.
There was 2 errors: - capturing the error as std::exception was dropping the error message - We need to check unconditionally (line 741) that no error was raised, in presence of multithreading. PiperOrigin-RevId: 332441152 Change-Id: I0aa663346a0b428050d212de22734b2eb362d488
This commit is contained in:
parent
55acd981d7
commit
ef1567488a
@ -239,7 +239,7 @@ struct CacheEntry {
|
||||
// a signature and if the object has been insterted already, other threads
|
||||
// will wait for the notification.
|
||||
absl::Notification compilation_complete;
|
||||
absl::optional<std::exception> compilation_error = absl::nullopt;
|
||||
absl::optional<Status> compilation_error = absl::nullopt;
|
||||
};
|
||||
|
||||
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
|
||||
@ -314,7 +314,7 @@ class CompiledFunction {
|
||||
// absl::optional<absl::Notification> is not supported
|
||||
bool first_compilation_started_ = false;
|
||||
absl::Notification first_compilation_complete_;
|
||||
absl::optional<std::exception> first_compilation_error_ = absl::nullopt;
|
||||
absl::optional<Status> first_compilation_error_ = absl::nullopt;
|
||||
};
|
||||
|
||||
CompiledFunction::CompiledFunction(py::function fun,
|
||||
@ -646,7 +646,8 @@ CacheEntry& CompiledFunction::GetCacheEntry(
|
||||
py::gil_scoped_release gil_release;
|
||||
found_iterator->second->compilation_complete.WaitForNotification();
|
||||
if (found_iterator->second->compilation_error) {
|
||||
throw found_iterator->second->compilation_error.value();
|
||||
throw std::invalid_argument(
|
||||
found_iterator->second->compilation_error.value().error_message());
|
||||
}
|
||||
}
|
||||
return *(found_iterator->second);
|
||||
@ -671,8 +672,8 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
|
||||
} else {
|
||||
try {
|
||||
executable_and_pytree = cache_miss_fun_(*args, **kwargs);
|
||||
} catch (const std::exception& e) {
|
||||
cache_entry.compilation_error = e;
|
||||
} catch (const py::error_already_set& e) {
|
||||
cache_entry.compilation_error = InvalidArgument("%s", e.what());
|
||||
cache_entry.compilation_complete.Notify();
|
||||
throw;
|
||||
}
|
||||
@ -736,16 +737,17 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
if (!first_compilation_complete_.HasBeenNotified()) {
|
||||
py::gil_scoped_release gil_release;
|
||||
first_compilation_complete_.WaitForNotification();
|
||||
if (first_compilation_error_) {
|
||||
throw first_compilation_error_.value();
|
||||
}
|
||||
}
|
||||
if (first_compilation_error_) {
|
||||
throw std::invalid_argument(
|
||||
first_compilation_error_.value().error_message());
|
||||
}
|
||||
} else {
|
||||
first_compilation_started_ = true;
|
||||
try {
|
||||
cache_miss_result = cache_miss_fun_(*args, **kwargs);
|
||||
} catch (const std::exception& e) {
|
||||
first_compilation_error_ = e;
|
||||
} catch (const py::error_already_set& e) {
|
||||
first_compilation_error_ = InvalidArgument("%s", e.what());
|
||||
first_compilation_complete_.Notify();
|
||||
throw;
|
||||
}
|
||||
@ -754,9 +756,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
|
||||
pyclient_ = executable->client();
|
||||
default_device_ = executable->LocalDevices()[0].contents;
|
||||
if (!default_device_) {
|
||||
throw std::invalid_argument(
|
||||
"executable->LocalDevices()[0] should not be null!");
|
||||
}
|
||||
first_compilation_complete_.Notify();
|
||||
}
|
||||
}
|
||||
CHECK(default_device_);
|
||||
|
||||
// The C++ jit do not support Tracers arguments yet. The Python-based jit
|
||||
// function will be called if any of the dynamic arguments is unsupported.
|
||||
|
Loading…
x
Reference in New Issue
Block a user