Correctly raise an error if the Python compilation failed.

There was 2 errors:
- capturing the error as std::exception was dropping the error message
- We need to check unconditionally (line 741) that no error was raised, in presence of multithreading.

PiperOrigin-RevId: 332441152
Change-Id: I0aa663346a0b428050d212de22734b2eb362d488
This commit is contained in:
A. Unique TensorFlower 2020-09-18 06:44:41 -07:00 committed by TensorFlower Gardener
parent 55acd981d7
commit ef1567488a

View File

@ -239,7 +239,7 @@ struct CacheEntry {
// a signature and if the object has been insterted already, other threads
// will wait for the notification.
absl::Notification compilation_complete;
absl::optional<std::exception> compilation_error = absl::nullopt;
absl::optional<Status> compilation_error = absl::nullopt;
};
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
@ -314,7 +314,7 @@ class CompiledFunction {
// absl::optional<absl::Notification> is not supported
bool first_compilation_started_ = false;
absl::Notification first_compilation_complete_;
absl::optional<std::exception> first_compilation_error_ = absl::nullopt;
absl::optional<Status> first_compilation_error_ = absl::nullopt;
};
CompiledFunction::CompiledFunction(py::function fun,
@ -646,7 +646,8 @@ CacheEntry& CompiledFunction::GetCacheEntry(
py::gil_scoped_release gil_release;
found_iterator->second->compilation_complete.WaitForNotification();
if (found_iterator->second->compilation_error) {
throw found_iterator->second->compilation_error.value();
throw std::invalid_argument(
found_iterator->second->compilation_error.value().error_message());
}
}
return *(found_iterator->second);
@ -671,8 +672,8 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
} else {
try {
executable_and_pytree = cache_miss_fun_(*args, **kwargs);
} catch (const std::exception& e) {
cache_entry.compilation_error = e;
} catch (const py::error_already_set& e) {
cache_entry.compilation_error = InvalidArgument("%s", e.what());
cache_entry.compilation_complete.Notify();
throw;
}
@ -736,16 +737,17 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
if (!first_compilation_complete_.HasBeenNotified()) {
py::gil_scoped_release gil_release;
first_compilation_complete_.WaitForNotification();
if (first_compilation_error_) {
throw first_compilation_error_.value();
}
}
if (first_compilation_error_) {
throw std::invalid_argument(
first_compilation_error_.value().error_message());
}
} else {
first_compilation_started_ = true;
try {
cache_miss_result = cache_miss_fun_(*args, **kwargs);
} catch (const std::exception& e) {
first_compilation_error_ = e;
} catch (const py::error_already_set& e) {
first_compilation_error_ = InvalidArgument("%s", e.what());
first_compilation_complete_.Notify();
throw;
}
@ -754,9 +756,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
pyclient_ = executable->client();
default_device_ = executable->LocalDevices()[0].contents;
if (!default_device_) {
throw std::invalid_argument(
"executable->LocalDevices()[0] should not be null!");
}
first_compilation_complete_.Notify();
}
}
CHECK(default_device_);
// The C++ jit do not support Tracers arguments yet. The Python-based jit
// function will be called if any of the dynamic arguments is unsupported.