Fallback when faced with other runtimes.
PiperOrigin-RevId: 336173651 Change-Id: Ia585abd06c4f24f77d5b9cabc4487d238b4bb4ac
This commit is contained in:
parent
ebe1d5c6c0
commit
d483dea0c6
@ -280,6 +280,8 @@ class CompiledFunction {
|
||||
py::object out_and_fastpath_data);
|
||||
bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }
|
||||
|
||||
bool always_fallback_to_python_ = false;
|
||||
|
||||
const py::function fun_; // The Python function to jit.
|
||||
// See JAX _cpp_jit in api.py for documentation.
|
||||
const py::function cache_miss_;
|
||||
@ -762,6 +764,9 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
|
||||
}
|
||||
|
||||
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
if (always_fallback_to_python_) {
|
||||
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
|
||||
}
|
||||
// Delayed values are retrieved on the first call to `Call`.
|
||||
if (!default_device_) {
|
||||
// As we are calling Python code, that may release the GIL, we first hold
|
||||
@ -775,8 +780,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
|
||||
if (!default_device_) {
|
||||
py::object device_and_is_committed = get_device_();
|
||||
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
|
||||
device_and_is_committed.attr("default_device"));
|
||||
try {
|
||||
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
|
||||
device_and_is_committed.attr("default_device"));
|
||||
} catch (const py::cast_error& e) {
|
||||
// Pathways and Cloud TPU 2VM runtime.
|
||||
always_fallback_to_python_ = true;
|
||||
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
|
||||
}
|
||||
default_pyclient_ = default_pydevice_.client;
|
||||
default_device_ = default_pydevice_.contents;
|
||||
is_committed_ =
|
||||
|
Loading…
Reference in New Issue
Block a user